{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ssms\n",
    "import lanfactory\n",
    "import os\n",
    "import numpy as np\n",
    "from copy import deepcopy\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the generative model; used below as the key into ssms.config.model_config\n",
    "# and as a component of the data / model folder paths.\n",
    "MODEL = \"angle\"\n",
    "# When False, the simulation cell further down is skipped (no new data is generated).\n",
    "RUN_SIMS = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the generator config (for MLP LANs)\n",
    "generator_config = deepcopy(ssms.config.data_generator_config[\"lan\"])\n",
    "# Specify generative model (one from the list of included models mentioned above)\n",
    "generator_config[\"dgp_list\"] = MODEL\n",
    "# Specify number of parameter sets to simulate\n",
    "generator_config[\"n_parameter_sets\"] = 256\n",
    "# Specify how many samples a simulation run should entail\n",
    "generator_config[\"n_samples\"] = 2000\n",
    "# Specify folder in which to save generated data\n",
    "generator_config[\"output_folder\"] = \"data/lan_mlp/\"\n",
    "\n",
    "# Make model config dict\n",
    "model_config = ssms.config.model_config[MODEL]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'name': 'angle',\n",
       " 'params': ['v', 'a', 'z', 't', 'theta'],\n",
       " 'param_bounds': [[-3.0, 0.3, 0.1, 0.001, -0.1], [3.0, 3.0, 0.9, 2.0, 1.3]],\n",
       " 'boundary': <function ssms.basic_simulators.boundary_functions.angle(t=1, theta=1)>,\n",
       " 'n_params': 5,\n",
       " 'default_params': [0.0, 1.0, 0.5, 0.001, 0.0],\n",
       " 'hddm_include': ['z', 'theta'],\n",
       " 'nchoices': 2}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the model configuration looked up for MODEL\n",
    "model_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator_config[\"output_folder\"] = (\n",
    "    \"data/lan_mlp/\"\n",
    "    + generator_config[\"dgp_list\"]\n",
    "    + \"/\"\n",
    "    + str(generator_config[\"n_samples\"])\n",
    "    + \"_\"\n",
    "    + str(generator_config[\"n_training_samples_by_parameter_set\"])\n",
    "    + \"/\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional simulation step: only runs when RUN_SIMS is True.\n",
    "if RUN_SIMS:\n",
    "    # Number of generate-and-save rounds\n",
    "    n_datafiles = 20\n",
    "    for i in range(n_datafiles):\n",
    "        print(\"Datafile: \", i)\n",
    "        # A fresh generator is built on every round from the same two configs\n",
    "        my_dataset_generator = ssms.dataset_generators.lan_mlp.data_generator(\n",
    "            generator_config=generator_config, model_config=model_config\n",
    "        )\n",
    "        # save=True presumably writes a file under generator_config[\"output_folder\"] -- confirm.\n",
    "        # Only the last round's return value stays bound to training_data.\n",
    "        training_data = my_dataset_generator.generate_data_training_uniform(save=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "my_data = pickle.load(\n",
    "    open(\n",
    "        \"data/lan_mlp/angle/2000_1000/training_data_8d617928653811eebb25a0423f3e9be0.pickle\",\n",
    "        \"rb\",\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ -1.68907  ,  -1.2301161,  -1.247614 , ..., -66.77497  ,\n",
       "       -66.77497  , -66.77497  ], dtype=float32)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the \"labels\" entry of the loaded training file\n",
    "my_data[\"labels\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network config: \n",
      "{'layer_sizes': [100, 100, 1], 'activations': ['tanh', 'tanh', 'linear'], 'train_output_type': 'logprob'}\n",
      "Train config: \n",
      "{'cpu_batch_size': 128, 'gpu_batch_size': 256, 'n_epochs': 5, 'optimizer': 'adam', 'learning_rate': 0.002, 'lr_scheduler': 'reduce_on_plateau', 'lr_scheduler_params': {}, 'weight_decay': 0.0, 'loss': 'huber', 'save_history': True}\n"
     ]
    }
   ],
   "source": [
    "network_config = lanfactory.config.network_configs.network_config_mlp\n",
    "\n",
    "print(\"Network config: \")\n",
    "print(network_config)\n",
    "\n",
    "train_config = lanfactory.config.network_configs.train_config_mlp\n",
    "\n",
    "print(\"Train config: \")\n",
    "print(train_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_config[\"cpu_batch_size\"] = 128\n",
    "train_config[\"gpu_batch_size\"] = 2048\n",
    "train_config[\"n_epochs\"] = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_ = \"data/lan_mlp/\" + MODEL + \"/2000_1000/\"\n",
    "file_list_ = [folder_ + file_ for file_ in os.listdir(folder_)]\n",
    "\n",
    "# Training dataset\n",
    "torch_training_dataset = lanfactory.trainers.DatasetTorch(\n",
    "    file_ids=file_list_,\n",
    "    batch_size=(\n",
    "        train_config[\"gpu_batch_size\"]\n",
    "        if torch.cuda.is_available()\n",
    "        else train_config[\"cpu_batch_size\"]\n",
    "    ),\n",
    "    label_lower_bound=np.log(1e-10),\n",
    "    features_key=\"data\",\n",
    "    label_key=\"labels\",\n",
    "    out_framework=\"torch\",\n",
    ")\n",
    "\n",
    "torch_training_dataloader = torch.utils.data.DataLoader(\n",
    "    torch_training_dataset,\n",
    "    shuffle=True,\n",
    "    batch_size=None,\n",
    "    num_workers=1,\n",
    "    pin_memory=True,\n",
    ")\n",
    "\n",
    "# Validation dataset\n",
    "torch_validation_dataset = lanfactory.trainers.DatasetTorch(\n",
    "    file_ids=file_list_,\n",
    "    batch_size=(\n",
    "        train_config[\"gpu_batch_size\"]\n",
    "        if torch.cuda.is_available()\n",
    "        else train_config[\"cpu_batch_size\"]\n",
    "    ),\n",
    "    label_lower_bound=np.log(1e-10),\n",
    "    features_key=\"data\",\n",
    "    label_key=\"labels\",\n",
    "    out_framework=\"torch\",\n",
    ")\n",
    "\n",
    "torch_validation_dataloader = torch.utils.data.DataLoader(\n",
    "    torch_validation_dataset,\n",
    "    shuffle=True,\n",
    "    batch_size=None,\n",
    "    num_workers=1,\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.2859,  1.5254,  0.8971,  ...,  0.8694,  1.4443,  1.0000],\n",
      "        [ 2.3546,  2.7225,  0.1075,  ...,  0.6222,  1.7985,  1.0000],\n",
      "        [ 2.5058,  2.0068,  0.3409,  ...,  0.6053,  1.3691,  1.0000],\n",
      "        ...,\n",
      "        [ 2.7323,  1.6003,  0.4028,  ...,  0.6330,  2.3316,  1.0000],\n",
      "        [ 2.2446,  2.0103,  0.3365,  ...,  0.8664,  2.3509,  1.0000],\n",
      "        [-0.2184,  1.0581,  0.4081,  ...,  0.0791,  6.5410,  1.0000]])\n",
      "tensor([[-1.3090],\n",
      "        [-0.0718],\n",
      "        [ 0.3968],\n",
      "        ...,\n",
      "        [-1.5956],\n",
      "        [ 0.5722],\n",
      "        [-8.6107]])\n"
     ]
    }
   ],
   "source": [
    "cnt = 0\n",
    "for xb, yb in torch_training_dataloader:\n",
    "    print(xb)\n",
    "    print(yb)\n",
    "    cnt += 1\n",
    "    if cnt > 0:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tanh\n",
      "tanh\n",
      "linear\n",
      "Found folder:  data\n",
      "Moving on...\n",
      "Found folder:  data/torch_models\n",
      "Moving on...\n",
      "Found folder:  data/torch_models/angle\n",
      "Moving on...\n",
      "Saved network config\n",
      "Saved train config\n"
     ]
    }
   ],
   "source": [
    "# LOAD NETWORK\n",
    "net = lanfactory.trainers.TorchMLP(\n",
    "    network_config=deepcopy(network_config),\n",
    "    input_shape=torch_training_dataset.input_dim,\n",
    ")\n",
    "\n",
    "# SAVE CONFIGS\n",
    "lanfactory.utils.save_configs(\n",
    "    model_id=MODEL + \"_torch\",\n",
    "    save_folder=\"data/torch_models/\" + MODEL + \"/\",\n",
    "    network_config=network_config,\n",
    "    train_config=train_config,\n",
    "    allow_abs_path_folder_generation=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Torch Device:  cuda\n",
      "train_config is passed as string:  data/torch_models/angle/angle_torch_train_config.pickle\n",
      "Trying to load string as path to pickle file: \n",
      "{'cpu_batch_size': 128, 'gpu_batch_size': 2048, 'n_epochs': 20, 'optimizer': 'adam', 'learning_rate': 0.002, 'lr_scheduler': 'reduce_on_plateau', 'lr_scheduler_params': {}, 'weight_decay': 0.0, 'loss': 'huber', 'save_history': True}\n"
     ]
    }
   ],
   "source": [
    "model_trainer = lanfactory.trainers.ModelTrainerTorchMLP(\n",
    "    train_config=\"data/torch_models/angle/angle_torch_train_config.pickle\",\n",
    "    model=net,\n",
    "    train_dl=torch_training_dataloader,\n",
    "    valid_dl=torch_validation_dataloader,\n",
    "    allow_abs_path_folder_generation=False,\n",
    "    pin_memory=True,\n",
    "    seed=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found folder:  data\n",
      "Moving on...\n",
      "Found folder:  data/trained_model\n",
      "Moving on...\n",
      "Found folder:  data/trained_model/torch\n",
      "Moving on...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mafengler\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.15.12"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/oscar/data/frankmj/afengler/proj_lanfactory/LANfactory/notebooks/test_notebooks/wandb/run-20231007_160318-a6ab4cv5</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/afengler/test_run_notebook/runs/a6ab4cv5' target=\"_blank\">wd_0.0_optim_adam_test_run_notebook</a></strong> to <a href='https://wandb.ai/afengler/test_run_notebook' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/afengler/test_run_notebook' target=\"_blank\">https://wandb.ai/afengler/test_run_notebook</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/afengler/test_run_notebook/runs/a6ab4cv5' target=\"_blank\">https://wandb.ai/afengler/test_run_notebook/runs/a6ab4cv5</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Succefully initialized wandb!\n",
      "epoch: 0 / 20, batch: 0 / 2440, batch_loss: 4.694940567016602\n",
      "epoch: 0 / 20, batch: 100 / 2440, batch_loss: 2.23306941986084\n",
      "epoch: 0 / 20, batch: 200 / 2440, batch_loss: 0.9773080348968506\n",
      "epoch: 0 / 20, batch: 300 / 2440, batch_loss: 0.7204380035400391\n",
      "epoch: 0 / 20, batch: 400 / 2440, batch_loss: 0.5427260398864746\n",
      "epoch: 0 / 20, batch: 500 / 2440, batch_loss: 0.5137326717376709\n",
      "epoch: 0 / 20, batch: 600 / 2440, batch_loss: 0.456951379776001\n",
      "epoch: 0 / 20, batch: 700 / 2440, batch_loss: 0.3506656289100647\n",
      "epoch: 0 / 20, batch: 800 / 2440, batch_loss: 0.35916250944137573\n",
      "epoch: 0 / 20, batch: 900 / 2440, batch_loss: 0.3065577745437622\n",
      "epoch: 0 / 20, batch: 1000 / 2440, batch_loss: 0.26735538244247437\n",
      "epoch: 0 / 20, batch: 1100 / 2440, batch_loss: 0.27880311012268066\n",
      "epoch: 0 / 20, batch: 1200 / 2440, batch_loss: 0.26584047079086304\n",
      "epoch: 0 / 20, batch: 1300 / 2440, batch_loss: 0.22467461228370667\n",
      "epoch: 0 / 20, batch: 1400 / 2440, batch_loss: 0.23233294486999512\n",
      "epoch: 0 / 20, batch: 1500 / 2440, batch_loss: 0.21492521464824677\n",
      "epoch: 0 / 20, batch: 1600 / 2440, batch_loss: 0.277679443359375\n",
      "epoch: 0 / 20, batch: 1700 / 2440, batch_loss: 0.19611068069934845\n",
      "epoch: 0 / 20, batch: 1800 / 2440, batch_loss: 0.21548068523406982\n",
      "epoch: 0 / 20, batch: 1900 / 2440, batch_loss: 0.22150284051895142\n",
      "epoch: 0 / 20, batch: 2000 / 2440, batch_loss: 0.21193864941596985\n",
      "epoch: 0 / 20, batch: 2100 / 2440, batch_loss: 0.2196582853794098\n",
      "epoch: 0 / 20, batch: 2200 / 2440, batch_loss: 0.17817966639995575\n",
      "epoch: 0 / 20, batch: 2300 / 2440, batch_loss: 0.17136915028095245\n",
      "epoch: 0 / 20, batch: 2400 / 2440, batch_loss: 0.1781916469335556\n",
      "Epoch took 0 / 20,  took 3.945970058441162 seconds\n",
      "epoch 0 / 20, validation_loss: 0.1762\n",
      "epoch: 1 / 20, batch: 0 / 2440, batch_loss: 0.1826338768005371\n",
      "epoch: 1 / 20, batch: 100 / 2440, batch_loss: 0.1777007281780243\n",
      "epoch: 1 / 20, batch: 200 / 2440, batch_loss: 0.18198856711387634\n",
      "epoch: 1 / 20, batch: 300 / 2440, batch_loss: 0.15861915051937103\n",
      "epoch: 1 / 20, batch: 400 / 2440, batch_loss: 0.13471561670303345\n",
      "epoch: 1 / 20, batch: 500 / 2440, batch_loss: 0.13207107782363892\n",
      "epoch: 1 / 20, batch: 600 / 2440, batch_loss: 0.15609858930110931\n",
      "epoch: 1 / 20, batch: 700 / 2440, batch_loss: 0.1497088223695755\n",
      "epoch: 1 / 20, batch: 800 / 2440, batch_loss: 0.1333540976047516\n",
      "epoch: 1 / 20, batch: 900 / 2440, batch_loss: 0.1420605182647705\n",
      "epoch: 1 / 20, batch: 1000 / 2440, batch_loss: 0.14186277985572815\n",
      "epoch: 1 / 20, batch: 1100 / 2440, batch_loss: 0.13303761184215546\n",
      "epoch: 1 / 20, batch: 1200 / 2440, batch_loss: 0.128230020403862\n",
      "epoch: 1 / 20, batch: 1300 / 2440, batch_loss: 0.09170198440551758\n",
      "epoch: 1 / 20, batch: 1400 / 2440, batch_loss: 0.13452035188674927\n",
      "epoch: 1 / 20, batch: 1500 / 2440, batch_loss: 0.11577793210744858\n",
      "epoch: 1 / 20, batch: 1600 / 2440, batch_loss: 0.13593247532844543\n",
      "epoch: 1 / 20, batch: 1700 / 2440, batch_loss: 0.10786199569702148\n",
      "epoch: 1 / 20, batch: 1800 / 2440, batch_loss: 0.13385529816150665\n",
      "epoch: 1 / 20, batch: 1900 / 2440, batch_loss: 0.10916884243488312\n",
      "epoch: 1 / 20, batch: 2000 / 2440, batch_loss: 0.07888741046190262\n",
      "epoch: 1 / 20, batch: 2100 / 2440, batch_loss: 0.12232646346092224\n",
      "epoch: 1 / 20, batch: 2200 / 2440, batch_loss: 0.11699306964874268\n",
      "epoch: 1 / 20, batch: 2300 / 2440, batch_loss: 0.09540767967700958\n",
      "epoch: 1 / 20, batch: 2400 / 2440, batch_loss: 0.08482584357261658\n",
      "Epoch took 1 / 20,  took 3.7771692276000977 seconds\n",
      "epoch 1 / 20, validation_loss: 0.1073\n",
      "epoch: 2 / 20, batch: 0 / 2440, batch_loss: 0.11805432289838791\n",
      "epoch: 2 / 20, batch: 100 / 2440, batch_loss: 0.09726420044898987\n",
      "epoch: 2 / 20, batch: 200 / 2440, batch_loss: 0.11088905483484268\n",
      "epoch: 2 / 20, batch: 300 / 2440, batch_loss: 0.08113141357898712\n",
      "epoch: 2 / 20, batch: 400 / 2440, batch_loss: 0.09579584002494812\n",
      "epoch: 2 / 20, batch: 500 / 2440, batch_loss: 0.11542697250843048\n",
      "epoch: 2 / 20, batch: 600 / 2440, batch_loss: 0.06346652656793594\n",
      "epoch: 2 / 20, batch: 700 / 2440, batch_loss: 0.11783280968666077\n",
      "epoch: 2 / 20, batch: 800 / 2440, batch_loss: 0.08076564967632294\n",
      "epoch: 2 / 20, batch: 900 / 2440, batch_loss: 0.07901154458522797\n",
      "epoch: 2 / 20, batch: 1000 / 2440, batch_loss: 0.05554930865764618\n",
      "epoch: 2 / 20, batch: 1100 / 2440, batch_loss: 0.06241322308778763\n",
      "epoch: 2 / 20, batch: 1200 / 2440, batch_loss: 0.11794434487819672\n",
      "epoch: 2 / 20, batch: 1300 / 2440, batch_loss: 0.09650713205337524\n",
      "epoch: 2 / 20, batch: 1400 / 2440, batch_loss: 0.05462045222520828\n",
      "epoch: 2 / 20, batch: 1500 / 2440, batch_loss: 0.06486432254314423\n",
      "epoch: 2 / 20, batch: 1600 / 2440, batch_loss: 0.08199632167816162\n",
      "epoch: 2 / 20, batch: 1700 / 2440, batch_loss: 0.09839476644992828\n",
      "epoch: 2 / 20, batch: 1800 / 2440, batch_loss: 0.07764306664466858\n",
      "epoch: 2 / 20, batch: 1900 / 2440, batch_loss: 0.07694990932941437\n",
      "epoch: 2 / 20, batch: 2000 / 2440, batch_loss: 0.06325538456439972\n",
      "epoch: 2 / 20, batch: 2100 / 2440, batch_loss: 0.10360311716794968\n",
      "epoch: 2 / 20, batch: 2200 / 2440, batch_loss: 0.0851779580116272\n",
      "epoch: 2 / 20, batch: 2300 / 2440, batch_loss: 0.07072775065898895\n",
      "epoch: 2 / 20, batch: 2400 / 2440, batch_loss: 0.04903118312358856\n",
      "Epoch took 2 / 20,  took 3.8361449241638184 seconds\n",
      "epoch 2 / 20, validation_loss: 0.07392\n",
      "epoch: 3 / 20, batch: 0 / 2440, batch_loss: 0.07679135352373123\n",
      "epoch: 3 / 20, batch: 100 / 2440, batch_loss: 0.051657307893037796\n",
      "epoch: 3 / 20, batch: 200 / 2440, batch_loss: 0.04805620014667511\n",
      "epoch: 3 / 20, batch: 300 / 2440, batch_loss: 0.057625770568847656\n",
      "epoch: 3 / 20, batch: 400 / 2440, batch_loss: 0.07225365936756134\n",
      "epoch: 3 / 20, batch: 500 / 2440, batch_loss: 0.07614465057849884\n",
      "epoch: 3 / 20, batch: 600 / 2440, batch_loss: 0.071581169962883\n",
      "epoch: 3 / 20, batch: 700 / 2440, batch_loss: 0.07754616439342499\n",
      "epoch: 3 / 20, batch: 800 / 2440, batch_loss: 0.06670297682285309\n",
      "epoch: 3 / 20, batch: 900 / 2440, batch_loss: 0.05127939581871033\n",
      "epoch: 3 / 20, batch: 1000 / 2440, batch_loss: 0.06461440026760101\n",
      "epoch: 3 / 20, batch: 1100 / 2440, batch_loss: 0.045543015003204346\n",
      "epoch: 3 / 20, batch: 1200 / 2440, batch_loss: 0.080250084400177\n",
      "epoch: 3 / 20, batch: 1300 / 2440, batch_loss: 0.0473194494843483\n",
      "epoch: 3 / 20, batch: 1400 / 2440, batch_loss: 0.05145417898893356\n",
      "epoch: 3 / 20, batch: 1500 / 2440, batch_loss: 0.049799226224422455\n",
      "epoch: 3 / 20, batch: 1600 / 2440, batch_loss: 0.04655963554978371\n",
      "epoch: 3 / 20, batch: 1700 / 2440, batch_loss: 0.05118640139698982\n",
      "epoch: 3 / 20, batch: 1800 / 2440, batch_loss: 0.041302479803562164\n",
      "epoch: 3 / 20, batch: 1900 / 2440, batch_loss: 0.05498439073562622\n",
      "epoch: 3 / 20, batch: 2000 / 2440, batch_loss: 0.03278551250696182\n",
      "epoch: 3 / 20, batch: 2100 / 2440, batch_loss: 0.07184702157974243\n",
      "epoch: 3 / 20, batch: 2200 / 2440, batch_loss: 0.04734371602535248\n",
      "epoch: 3 / 20, batch: 2300 / 2440, batch_loss: 0.03805344179272652\n",
      "epoch: 3 / 20, batch: 2400 / 2440, batch_loss: 0.03615497797727585\n",
      "Epoch took 3 / 20,  took 3.759489059448242 seconds\n",
      "epoch 3 / 20, validation_loss: 0.06279\n",
      "epoch: 4 / 20, batch: 0 / 2440, batch_loss: 0.07347521185874939\n",
      "epoch: 4 / 20, batch: 100 / 2440, batch_loss: 0.036819640547037125\n",
      "epoch: 4 / 20, batch: 200 / 2440, batch_loss: 0.06594939529895782\n",
      "epoch: 4 / 20, batch: 300 / 2440, batch_loss: 0.07462650537490845\n",
      "epoch: 4 / 20, batch: 400 / 2440, batch_loss: 0.03951900079846382\n",
      "epoch: 4 / 20, batch: 500 / 2440, batch_loss: 0.05384526401758194\n",
      "epoch: 4 / 20, batch: 600 / 2440, batch_loss: 0.03164060413837433\n",
      "epoch: 4 / 20, batch: 700 / 2440, batch_loss: 0.03748104348778725\n",
      "epoch: 4 / 20, batch: 800 / 2440, batch_loss: 0.03482696786522865\n",
      "epoch: 4 / 20, batch: 900 / 2440, batch_loss: 0.05877228453755379\n",
      "epoch: 4 / 20, batch: 1000 / 2440, batch_loss: 0.0540461465716362\n",
      "epoch: 4 / 20, batch: 1100 / 2440, batch_loss: 0.07199061661958694\n",
      "epoch: 4 / 20, batch: 1200 / 2440, batch_loss: 0.03799916058778763\n",
      "epoch: 4 / 20, batch: 1300 / 2440, batch_loss: 0.04132772237062454\n",
      "epoch: 4 / 20, batch: 1400 / 2440, batch_loss: 0.061114098876714706\n",
      "epoch: 4 / 20, batch: 1500 / 2440, batch_loss: 0.033305007964372635\n",
      "epoch: 4 / 20, batch: 1600 / 2440, batch_loss: 0.051266543567180634\n",
      "epoch: 4 / 20, batch: 1700 / 2440, batch_loss: 0.050168298184871674\n",
      "epoch: 4 / 20, batch: 1800 / 2440, batch_loss: 0.06299306452274323\n",
      "epoch: 4 / 20, batch: 1900 / 2440, batch_loss: 0.04107462614774704\n",
      "epoch: 4 / 20, batch: 2000 / 2440, batch_loss: 0.04401233047246933\n",
      "epoch: 4 / 20, batch: 2100 / 2440, batch_loss: 0.034776754677295685\n",
      "epoch: 4 / 20, batch: 2200 / 2440, batch_loss: 0.034614548087120056\n",
      "epoch: 4 / 20, batch: 2300 / 2440, batch_loss: 0.035181351006031036\n",
      "epoch: 4 / 20, batch: 2400 / 2440, batch_loss: 0.03762952983379364\n",
      "Epoch took 4 / 20,  took 3.7798702716827393 seconds\n",
      "epoch 4 / 20, validation_loss: 0.0537\n",
      "epoch: 5 / 20, batch: 0 / 2440, batch_loss: 0.058970287442207336\n",
      "epoch: 5 / 20, batch: 100 / 2440, batch_loss: 0.03451056778430939\n",
      "epoch: 5 / 20, batch: 200 / 2440, batch_loss: 0.043514806777238846\n",
      "epoch: 5 / 20, batch: 300 / 2440, batch_loss: 0.03818107023835182\n",
      "epoch: 5 / 20, batch: 400 / 2440, batch_loss: 0.0603254996240139\n",
      "epoch: 5 / 20, batch: 500 / 2440, batch_loss: 0.05138325318694115\n",
      "epoch: 5 / 20, batch: 600 / 2440, batch_loss: 0.03114839643239975\n",
      "epoch: 5 / 20, batch: 700 / 2440, batch_loss: 0.038076482713222504\n",
      "epoch: 5 / 20, batch: 800 / 2440, batch_loss: 0.03831346333026886\n",
      "epoch: 5 / 20, batch: 900 / 2440, batch_loss: 0.05197165161371231\n",
      "epoch: 5 / 20, batch: 1000 / 2440, batch_loss: 0.043660685420036316\n",
      "epoch: 5 / 20, batch: 1100 / 2440, batch_loss: 0.04596983641386032\n",
      "epoch: 5 / 20, batch: 1200 / 2440, batch_loss: 0.03947557881474495\n",
      "epoch: 5 / 20, batch: 1300 / 2440, batch_loss: 0.036666207015514374\n",
      "epoch: 5 / 20, batch: 1400 / 2440, batch_loss: 0.064993716776371\n",
      "epoch: 5 / 20, batch: 1500 / 2440, batch_loss: 0.027076376602053642\n",
      "epoch: 5 / 20, batch: 1600 / 2440, batch_loss: 0.04253571853041649\n",
      "epoch: 5 / 20, batch: 1700 / 2440, batch_loss: 0.04271668195724487\n",
      "epoch: 5 / 20, batch: 1800 / 2440, batch_loss: 0.0677276998758316\n",
      "epoch: 5 / 20, batch: 1900 / 2440, batch_loss: 0.03863006830215454\n",
      "epoch: 5 / 20, batch: 2000 / 2440, batch_loss: 0.054879751056432724\n",
      "epoch: 5 / 20, batch: 2100 / 2440, batch_loss: 0.04007860645651817\n",
      "epoch: 5 / 20, batch: 2200 / 2440, batch_loss: 0.030957616865634918\n",
      "epoch: 5 / 20, batch: 2300 / 2440, batch_loss: 0.04104179888963699\n",
      "epoch: 5 / 20, batch: 2400 / 2440, batch_loss: 0.040289998054504395\n",
      "Epoch took 5 / 20,  took 3.7973315715789795 seconds\n",
      "epoch 5 / 20, validation_loss: 0.05091\n",
      "epoch: 6 / 20, batch: 0 / 2440, batch_loss: 0.049600519239902496\n",
      "epoch: 6 / 20, batch: 100 / 2440, batch_loss: 0.04420928657054901\n",
      "epoch: 6 / 20, batch: 200 / 2440, batch_loss: 0.03834342956542969\n",
      "epoch: 6 / 20, batch: 300 / 2440, batch_loss: 0.04654310643672943\n",
      "epoch: 6 / 20, batch: 400 / 2440, batch_loss: 0.03228059783577919\n",
      "epoch: 6 / 20, batch: 500 / 2440, batch_loss: 0.04667595028877258\n",
      "epoch: 6 / 20, batch: 600 / 2440, batch_loss: 0.037967499345541\n",
      "epoch: 6 / 20, batch: 700 / 2440, batch_loss: 0.03553421422839165\n",
      "epoch: 6 / 20, batch: 800 / 2440, batch_loss: 0.05953478813171387\n",
      "epoch: 6 / 20, batch: 900 / 2440, batch_loss: 0.05120183527469635\n",
      "epoch: 6 / 20, batch: 1000 / 2440, batch_loss: 0.06989821791648865\n",
      "epoch: 6 / 20, batch: 1100 / 2440, batch_loss: 0.045672208070755005\n",
      "epoch: 6 / 20, batch: 1200 / 2440, batch_loss: 0.032889775931835175\n",
      "epoch: 6 / 20, batch: 1300 / 2440, batch_loss: 0.0328468419611454\n",
      "epoch: 6 / 20, batch: 1400 / 2440, batch_loss: 0.03037184476852417\n",
      "epoch: 6 / 20, batch: 1500 / 2440, batch_loss: 0.04612214118242264\n",
      "epoch: 6 / 20, batch: 1600 / 2440, batch_loss: 0.0657576322555542\n",
      "epoch: 6 / 20, batch: 1700 / 2440, batch_loss: 0.03831800818443298\n",
      "epoch: 6 / 20, batch: 1800 / 2440, batch_loss: 0.02955693006515503\n",
      "epoch: 6 / 20, batch: 1900 / 2440, batch_loss: 0.041114866733551025\n",
      "epoch: 6 / 20, batch: 2000 / 2440, batch_loss: 0.025669096037745476\n",
      "epoch: 6 / 20, batch: 2100 / 2440, batch_loss: 0.032515812665224075\n",
      "epoch: 6 / 20, batch: 2200 / 2440, batch_loss: 0.0394020713865757\n",
      "epoch: 6 / 20, batch: 2300 / 2440, batch_loss: 0.05317234992980957\n",
      "epoch: 6 / 20, batch: 2400 / 2440, batch_loss: 0.059973377734422684\n",
      "Epoch took 6 / 20,  took 3.8066840171813965 seconds\n",
      "epoch 6 / 20, validation_loss: 0.04509\n",
      "epoch: 7 / 20, batch: 0 / 2440, batch_loss: 0.04944422096014023\n",
      "epoch: 7 / 20, batch: 100 / 2440, batch_loss: 0.0326726958155632\n",
      "epoch: 7 / 20, batch: 200 / 2440, batch_loss: 0.02890639193356037\n",
      "epoch: 7 / 20, batch: 300 / 2440, batch_loss: 0.028051769360899925\n",
      "epoch: 7 / 20, batch: 400 / 2440, batch_loss: 0.03323528543114662\n",
      "epoch: 7 / 20, batch: 500 / 2440, batch_loss: 0.03238992393016815\n",
      "epoch: 7 / 20, batch: 600 / 2440, batch_loss: 0.04166383668780327\n",
      "epoch: 7 / 20, batch: 700 / 2440, batch_loss: 0.05008497089147568\n",
      "epoch: 7 / 20, batch: 800 / 2440, batch_loss: 0.03277754783630371\n",
      "epoch: 7 / 20, batch: 900 / 2440, batch_loss: 0.03470487520098686\n",
      "epoch: 7 / 20, batch: 1000 / 2440, batch_loss: 0.04323090612888336\n",
      "epoch: 7 / 20, batch: 1100 / 2440, batch_loss: 0.030652744695544243\n",
      "epoch: 7 / 20, batch: 1200 / 2440, batch_loss: 0.03418680280447006\n",
      "epoch: 7 / 20, batch: 1300 / 2440, batch_loss: 0.03901415318250656\n",
      "epoch: 7 / 20, batch: 1400 / 2440, batch_loss: 0.0372593030333519\n",
      "epoch: 7 / 20, batch: 1500 / 2440, batch_loss: 0.032655980437994\n",
      "epoch: 7 / 20, batch: 1600 / 2440, batch_loss: 0.026790719479322433\n",
      "epoch: 7 / 20, batch: 1700 / 2440, batch_loss: 0.04849889129400253\n",
      "epoch: 7 / 20, batch: 1800 / 2440, batch_loss: 0.057230208069086075\n",
      "epoch: 7 / 20, batch: 1900 / 2440, batch_loss: 0.03678677976131439\n",
      "epoch: 7 / 20, batch: 2000 / 2440, batch_loss: 0.047493305057287216\n",
      "epoch: 7 / 20, batch: 2100 / 2440, batch_loss: 0.05061342939734459\n",
      "epoch: 7 / 20, batch: 2200 / 2440, batch_loss: 0.05488424748182297\n",
      "epoch: 7 / 20, batch: 2300 / 2440, batch_loss: 0.042903073132038116\n",
      "epoch: 7 / 20, batch: 2400 / 2440, batch_loss: 0.03848975896835327\n",
      "Epoch took 7 / 20,  took 3.776787519454956 seconds\n",
      "epoch 7 / 20, validation_loss: 0.04182\n",
      "epoch: 8 / 20, batch: 0 / 2440, batch_loss: 0.04804763197898865\n",
      "epoch: 8 / 20, batch: 100 / 2440, batch_loss: 0.041079599410295486\n",
      "epoch: 8 / 20, batch: 200 / 2440, batch_loss: 0.033491142094135284\n",
      "epoch: 8 / 20, batch: 300 / 2440, batch_loss: 0.047172315418720245\n",
      "epoch: 8 / 20, batch: 400 / 2440, batch_loss: 0.031809814274311066\n",
      "epoch: 8 / 20, batch: 500 / 2440, batch_loss: 0.052292924374341965\n",
      "epoch: 8 / 20, batch: 600 / 2440, batch_loss: 0.040862005203962326\n",
      "epoch: 8 / 20, batch: 700 / 2440, batch_loss: 0.04492181912064552\n",
      "epoch: 8 / 20, batch: 800 / 2440, batch_loss: 0.043378330767154694\n",
      "epoch: 8 / 20, batch: 900 / 2440, batch_loss: 0.033781737089157104\n",
      "epoch: 8 / 20, batch: 1000 / 2440, batch_loss: 0.05309645086526871\n",
      "epoch: 8 / 20, batch: 1100 / 2440, batch_loss: 0.03822539374232292\n",
      "epoch: 8 / 20, batch: 1200 / 2440, batch_loss: 0.028831275179982185\n",
      "epoch: 8 / 20, batch: 1300 / 2440, batch_loss: 0.03361521661281586\n",
      "epoch: 8 / 20, batch: 1400 / 2440, batch_loss: 0.03693133592605591\n",
      "epoch: 8 / 20, batch: 1500 / 2440, batch_loss: 0.04942905902862549\n",
      "epoch: 8 / 20, batch: 1600 / 2440, batch_loss: 0.03967609256505966\n",
      "epoch: 8 / 20, batch: 1700 / 2440, batch_loss: 0.052534446120262146\n",
      "epoch: 8 / 20, batch: 1800 / 2440, batch_loss: 0.03175407648086548\n",
      "epoch: 8 / 20, batch: 1900 / 2440, batch_loss: 0.025437094271183014\n",
      "epoch: 8 / 20, batch: 2000 / 2440, batch_loss: 0.04650438576936722\n",
      "epoch: 8 / 20, batch: 2100 / 2440, batch_loss: 0.03202113136649132\n",
      "epoch: 8 / 20, batch: 2200 / 2440, batch_loss: 0.0379919558763504\n",
      "epoch: 8 / 20, batch: 2300 / 2440, batch_loss: 0.039869122207164764\n",
      "epoch: 8 / 20, batch: 2400 / 2440, batch_loss: 0.03146914392709732\n",
      "Epoch took 8 / 20,  took 3.767800807952881 seconds\n",
      "epoch 8 / 20, validation_loss: 0.04724\n",
      "epoch: 9 / 20, batch: 0 / 2440, batch_loss: 0.04871995002031326\n",
      "epoch: 9 / 20, batch: 100 / 2440, batch_loss: 0.030478447675704956\n",
      "epoch: 9 / 20, batch: 200 / 2440, batch_loss: 0.03497273847460747\n",
      "epoch: 9 / 20, batch: 300 / 2440, batch_loss: 0.03601894527673721\n",
      "epoch: 9 / 20, batch: 400 / 2440, batch_loss: 0.03932525962591171\n",
      "epoch: 9 / 20, batch: 500 / 2440, batch_loss: 0.03290287405252457\n",
      "epoch: 9 / 20, batch: 600 / 2440, batch_loss: 0.0215008407831192\n",
      "epoch: 9 / 20, batch: 700 / 2440, batch_loss: 0.04798482358455658\n",
      "epoch: 9 / 20, batch: 800 / 2440, batch_loss: 0.04026862978935242\n",
      "epoch: 9 / 20, batch: 900 / 2440, batch_loss: 0.02946610562503338\n",
      "epoch: 9 / 20, batch: 1000 / 2440, batch_loss: 0.028596654534339905\n",
      "epoch: 9 / 20, batch: 1100 / 2440, batch_loss: 0.023426536470651627\n",
      "epoch: 9 / 20, batch: 1200 / 2440, batch_loss: 0.029500167816877365\n",
      "epoch: 9 / 20, batch: 1300 / 2440, batch_loss: 0.03612615540623665\n",
      "epoch: 9 / 20, batch: 1400 / 2440, batch_loss: 0.02643858641386032\n",
      "epoch: 9 / 20, batch: 1500 / 2440, batch_loss: 0.04684673622250557\n",
      "epoch: 9 / 20, batch: 1600 / 2440, batch_loss: 0.03504561632871628\n",
      "epoch: 9 / 20, batch: 1700 / 2440, batch_loss: 0.02864663675427437\n",
      "epoch: 9 / 20, batch: 1800 / 2440, batch_loss: 0.04048699513077736\n",
      "epoch: 9 / 20, batch: 1900 / 2440, batch_loss: 0.03871076554059982\n",
      "epoch: 9 / 20, batch: 2000 / 2440, batch_loss: 0.026965482160449028\n",
      "epoch: 9 / 20, batch: 2100 / 2440, batch_loss: 0.02263130620121956\n",
      "epoch: 9 / 20, batch: 2200 / 2440, batch_loss: 0.024637384340167046\n",
      "epoch: 9 / 20, batch: 2300 / 2440, batch_loss: 0.04596192389726639\n",
      "epoch: 9 / 20, batch: 2400 / 2440, batch_loss: 0.0301503986120224\n",
      "Epoch took 9 / 20,  took 3.799872875213623 seconds\n",
      "epoch 9 / 20, validation_loss: 0.03842\n",
      "epoch: 10 / 20, batch: 0 / 2440, batch_loss: 0.042499881237745285\n",
      "epoch: 10 / 20, batch: 100 / 2440, batch_loss: 0.026805639266967773\n",
      "epoch: 10 / 20, batch: 200 / 2440, batch_loss: 0.05220847949385643\n",
      "epoch: 10 / 20, batch: 300 / 2440, batch_loss: 0.03255221247673035\n",
      "epoch: 10 / 20, batch: 400 / 2440, batch_loss: 0.041427887976169586\n",
      "epoch: 10 / 20, batch: 500 / 2440, batch_loss: 0.04102826118469238\n",
      "epoch: 10 / 20, batch: 600 / 2440, batch_loss: 0.033548641949892044\n",
      "epoch: 10 / 20, batch: 700 / 2440, batch_loss: 0.031052857637405396\n",
      "epoch: 10 / 20, batch: 800 / 2440, batch_loss: 0.05130310356616974\n",
      "epoch: 10 / 20, batch: 900 / 2440, batch_loss: 0.052247609943151474\n",
      "epoch: 10 / 20, batch: 1000 / 2440, batch_loss: 0.019706740975379944\n",
      "epoch: 10 / 20, batch: 1100 / 2440, batch_loss: 0.04006282985210419\n",
      "epoch: 10 / 20, batch: 1200 / 2440, batch_loss: 0.03811828792095184\n",
      "epoch: 10 / 20, batch: 1300 / 2440, batch_loss: 0.03172377869486809\n",
      "epoch: 10 / 20, batch: 1400 / 2440, batch_loss: 0.02417786978185177\n",
      "epoch: 10 / 20, batch: 1500 / 2440, batch_loss: 0.0369187593460083\n",
      "epoch: 10 / 20, batch: 1600 / 2440, batch_loss: 0.03333206847310066\n",
      "epoch: 10 / 20, batch: 1700 / 2440, batch_loss: 0.032340604811906815\n",
      "epoch: 10 / 20, batch: 1800 / 2440, batch_loss: 0.04050077125430107\n",
      "epoch: 10 / 20, batch: 1900 / 2440, batch_loss: 0.041986845433712006\n",
      "epoch: 10 / 20, batch: 2000 / 2440, batch_loss: 0.03727627545595169\n",
      "epoch: 10 / 20, batch: 2100 / 2440, batch_loss: 0.03400562331080437\n",
      "epoch: 10 / 20, batch: 2200 / 2440, batch_loss: 0.04486893117427826\n",
      "epoch: 10 / 20, batch: 2300 / 2440, batch_loss: 0.03929939121007919\n",
      "epoch: 10 / 20, batch: 2400 / 2440, batch_loss: 0.04686536639928818\n",
      "Epoch took 10 / 20,  took 3.7557802200317383 seconds\n",
      "epoch 10 / 20, validation_loss: 0.03715\n",
      "epoch: 11 / 20, batch: 0 / 2440, batch_loss: 0.0469893217086792\n",
      "epoch: 11 / 20, batch: 100 / 2440, batch_loss: 0.034965284168720245\n",
      "epoch: 11 / 20, batch: 200 / 2440, batch_loss: 0.035857390612363815\n",
      "epoch: 11 / 20, batch: 300 / 2440, batch_loss: 0.03229210898280144\n",
      "epoch: 11 / 20, batch: 400 / 2440, batch_loss: 0.0481647253036499\n",
      "epoch: 11 / 20, batch: 500 / 2440, batch_loss: 0.02914625220000744\n",
      "epoch: 11 / 20, batch: 600 / 2440, batch_loss: 0.032419368624687195\n",
      "epoch: 11 / 20, batch: 700 / 2440, batch_loss: 0.039861105382442474\n",
      "epoch: 11 / 20, batch: 800 / 2440, batch_loss: 0.03774431347846985\n",
      "epoch: 11 / 20, batch: 900 / 2440, batch_loss: 0.03822353854775429\n",
      "epoch: 11 / 20, batch: 1000 / 2440, batch_loss: 0.03757942467927933\n",
      "epoch: 11 / 20, batch: 1100 / 2440, batch_loss: 0.029089247807860374\n",
      "epoch: 11 / 20, batch: 1200 / 2440, batch_loss: 0.03364301845431328\n",
      "epoch: 11 / 20, batch: 1300 / 2440, batch_loss: 0.029797133058309555\n",
      "epoch: 11 / 20, batch: 1400 / 2440, batch_loss: 0.03605165705084801\n",
      "epoch: 11 / 20, batch: 1500 / 2440, batch_loss: 0.027721336111426353\n",
      "epoch: 11 / 20, batch: 1600 / 2440, batch_loss: 0.03121432662010193\n",
      "epoch: 11 / 20, batch: 1700 / 2440, batch_loss: 0.02664719894528389\n",
      "epoch: 11 / 20, batch: 1800 / 2440, batch_loss: 0.034658033400774\n",
      "epoch: 11 / 20, batch: 1900 / 2440, batch_loss: 0.025839686393737793\n",
      "epoch: 11 / 20, batch: 2000 / 2440, batch_loss: 0.03442705422639847\n",
      "epoch: 11 / 20, batch: 2100 / 2440, batch_loss: 0.041235245764255524\n",
      "epoch: 11 / 20, batch: 2200 / 2440, batch_loss: 0.03513164818286896\n",
      "epoch: 11 / 20, batch: 2300 / 2440, batch_loss: 0.034954532980918884\n",
      "epoch: 11 / 20, batch: 2400 / 2440, batch_loss: 0.0348348468542099\n",
      "Epoch took 11 / 20,  took 3.7944674491882324 seconds\n",
      "epoch 11 / 20, validation_loss: 0.03799\n",
      "epoch: 12 / 20, batch: 0 / 2440, batch_loss: 0.04894629120826721\n",
      "epoch: 12 / 20, batch: 100 / 2440, batch_loss: 0.040492549538612366\n",
      "epoch: 12 / 20, batch: 200 / 2440, batch_loss: 0.02955630235373974\n",
      "epoch: 12 / 20, batch: 300 / 2440, batch_loss: 0.03216380253434181\n",
      "epoch: 12 / 20, batch: 400 / 2440, batch_loss: 0.043747060000896454\n",
      "epoch: 12 / 20, batch: 500 / 2440, batch_loss: 0.03451559692621231\n",
      "epoch: 12 / 20, batch: 600 / 2440, batch_loss: 0.03667508065700531\n",
      "epoch: 12 / 20, batch: 700 / 2440, batch_loss: 0.030844207853078842\n",
      "epoch: 12 / 20, batch: 800 / 2440, batch_loss: 0.04029500484466553\n",
      "epoch: 12 / 20, batch: 900 / 2440, batch_loss: 0.033438898622989655\n",
      "epoch: 12 / 20, batch: 1000 / 2440, batch_loss: 0.039861150085926056\n",
      "epoch: 12 / 20, batch: 1100 / 2440, batch_loss: 0.04575042799115181\n",
      "epoch: 12 / 20, batch: 1200 / 2440, batch_loss: 0.016101617366075516\n",
      "epoch: 12 / 20, batch: 1300 / 2440, batch_loss: 0.03319504112005234\n",
      "epoch: 12 / 20, batch: 1400 / 2440, batch_loss: 0.03771434724330902\n",
      "epoch: 12 / 20, batch: 1500 / 2440, batch_loss: 0.02632061019539833\n",
      "epoch: 12 / 20, batch: 1600 / 2440, batch_loss: 0.034060388803482056\n",
      "epoch: 12 / 20, batch: 1700 / 2440, batch_loss: 0.032366909086704254\n",
      "epoch: 12 / 20, batch: 1800 / 2440, batch_loss: 0.01812577247619629\n",
      "epoch: 12 / 20, batch: 1900 / 2440, batch_loss: 0.04738713428378105\n",
      "epoch: 12 / 20, batch: 2000 / 2440, batch_loss: 0.05877498537302017\n",
      "epoch: 12 / 20, batch: 2100 / 2440, batch_loss: 0.027764886617660522\n",
      "epoch: 12 / 20, batch: 2200 / 2440, batch_loss: 0.027601752430200577\n",
      "epoch: 12 / 20, batch: 2300 / 2440, batch_loss: 0.027481580153107643\n",
      "epoch: 12 / 20, batch: 2400 / 2440, batch_loss: 0.05140882357954979\n",
      "Epoch took 12 / 20,  took 3.7883894443511963 seconds\n",
      "epoch 12 / 20, validation_loss: 0.03727\n",
      "epoch: 13 / 20, batch: 0 / 2440, batch_loss: 0.03451637551188469\n",
      "epoch: 13 / 20, batch: 100 / 2440, batch_loss: 0.014483795501291752\n",
      "epoch: 13 / 20, batch: 200 / 2440, batch_loss: 0.030513012781739235\n",
      "epoch: 13 / 20, batch: 300 / 2440, batch_loss: 0.02268180623650551\n",
      "epoch: 13 / 20, batch: 400 / 2440, batch_loss: 0.032136719673871994\n",
      "epoch: 13 / 20, batch: 500 / 2440, batch_loss: 0.02936648763716221\n",
      "epoch: 13 / 20, batch: 600 / 2440, batch_loss: 0.032676368951797485\n",
      "epoch: 13 / 20, batch: 700 / 2440, batch_loss: 0.021279409527778625\n",
      "epoch: 13 / 20, batch: 800 / 2440, batch_loss: 0.027890227735042572\n",
      "epoch: 13 / 20, batch: 900 / 2440, batch_loss: 0.029240507632493973\n",
      "epoch: 13 / 20, batch: 1000 / 2440, batch_loss: 0.04501231014728546\n",
      "epoch: 13 / 20, batch: 1100 / 2440, batch_loss: 0.04826994985342026\n",
      "epoch: 13 / 20, batch: 1200 / 2440, batch_loss: 0.023019563406705856\n",
      "epoch: 13 / 20, batch: 1300 / 2440, batch_loss: 0.025064054876565933\n",
      "epoch: 13 / 20, batch: 1400 / 2440, batch_loss: 0.042331013828516006\n",
      "epoch: 13 / 20, batch: 1500 / 2440, batch_loss: 0.024131953716278076\n",
      "epoch: 13 / 20, batch: 1600 / 2440, batch_loss: 0.02529831975698471\n",
      "epoch: 13 / 20, batch: 1700 / 2440, batch_loss: 0.016694769263267517\n",
      "epoch: 13 / 20, batch: 1800 / 2440, batch_loss: 0.030881457030773163\n",
      "epoch: 13 / 20, batch: 1900 / 2440, batch_loss: 0.03252216428518295\n",
      "epoch: 13 / 20, batch: 2000 / 2440, batch_loss: 0.02312810719013214\n",
      "epoch: 13 / 20, batch: 2100 / 2440, batch_loss: 0.03289248049259186\n",
      "epoch: 13 / 20, batch: 2200 / 2440, batch_loss: 0.0358734130859375\n",
      "epoch: 13 / 20, batch: 2300 / 2440, batch_loss: 0.03797478228807449\n",
      "epoch: 13 / 20, batch: 2400 / 2440, batch_loss: 0.03524656593799591\n",
      "Epoch took 13 / 20,  took 3.7552502155303955 seconds\n",
      "epoch 13 / 20, validation_loss: 0.0374\n",
      "Epoch 00014: reducing learning rate of group 0 to 2.0000e-04.\n",
      "epoch: 14 / 20, batch: 0 / 2440, batch_loss: 0.02912684716284275\n",
      "epoch: 14 / 20, batch: 100 / 2440, batch_loss: 0.03049331344664097\n",
      "epoch: 14 / 20, batch: 200 / 2440, batch_loss: 0.01733759045600891\n",
      "epoch: 14 / 20, batch: 300 / 2440, batch_loss: 0.030962791293859482\n",
      "epoch: 14 / 20, batch: 400 / 2440, batch_loss: 0.028138073161244392\n",
      "epoch: 14 / 20, batch: 500 / 2440, batch_loss: 0.02338271029293537\n",
      "epoch: 14 / 20, batch: 600 / 2440, batch_loss: 0.03134305030107498\n",
      "epoch: 14 / 20, batch: 700 / 2440, batch_loss: 0.011347610503435135\n",
      "epoch: 14 / 20, batch: 800 / 2440, batch_loss: 0.027137892320752144\n",
      "epoch: 14 / 20, batch: 900 / 2440, batch_loss: 0.0439368337392807\n",
      "epoch: 14 / 20, batch: 1000 / 2440, batch_loss: 0.03849764168262482\n",
      "epoch: 14 / 20, batch: 1100 / 2440, batch_loss: 0.03778647258877754\n",
      "epoch: 14 / 20, batch: 1200 / 2440, batch_loss: 0.0169899333268404\n",
      "epoch: 14 / 20, batch: 1300 / 2440, batch_loss: 0.03525102138519287\n",
      "epoch: 14 / 20, batch: 1400 / 2440, batch_loss: 0.02446819469332695\n",
      "epoch: 14 / 20, batch: 1500 / 2440, batch_loss: 0.032972585409879684\n",
      "epoch: 14 / 20, batch: 1600 / 2440, batch_loss: 0.025884972885251045\n",
      "epoch: 14 / 20, batch: 1700 / 2440, batch_loss: 0.04392126947641373\n",
      "epoch: 14 / 20, batch: 1800 / 2440, batch_loss: 0.03282436728477478\n",
      "epoch: 14 / 20, batch: 1900 / 2440, batch_loss: 0.024776913225650787\n",
      "epoch: 14 / 20, batch: 2000 / 2440, batch_loss: 0.03007996641099453\n",
      "epoch: 14 / 20, batch: 2100 / 2440, batch_loss: 0.029154345393180847\n",
      "epoch: 14 / 20, batch: 2200 / 2440, batch_loss: 0.031661685556173325\n",
      "epoch: 14 / 20, batch: 2300 / 2440, batch_loss: 0.03704719990491867\n",
      "epoch: 14 / 20, batch: 2400 / 2440, batch_loss: 0.02390027604997158\n",
      "Epoch took 14 / 20,  took 3.756432294845581 seconds\n",
      "epoch 14 / 20, validation_loss: 0.03188\n",
      "epoch: 15 / 20, batch: 0 / 2440, batch_loss: 0.025300465524196625\n",
      "epoch: 15 / 20, batch: 100 / 2440, batch_loss: 0.02845795825123787\n",
      "epoch: 15 / 20, batch: 200 / 2440, batch_loss: 0.0445672832429409\n",
      "epoch: 15 / 20, batch: 300 / 2440, batch_loss: 0.03395172953605652\n",
      "epoch: 15 / 20, batch: 400 / 2440, batch_loss: 0.030964139848947525\n",
      "epoch: 15 / 20, batch: 500 / 2440, batch_loss: 0.019801460206508636\n",
      "epoch: 15 / 20, batch: 600 / 2440, batch_loss: 0.03182527422904968\n",
      "epoch: 15 / 20, batch: 700 / 2440, batch_loss: 0.028369560837745667\n",
      "epoch: 15 / 20, batch: 800 / 2440, batch_loss: 0.028375694528222084\n",
      "epoch: 15 / 20, batch: 900 / 2440, batch_loss: 0.026844147592782974\n",
      "epoch: 15 / 20, batch: 1000 / 2440, batch_loss: 0.02581721544265747\n",
      "epoch: 15 / 20, batch: 1100 / 2440, batch_loss: 0.01818961650133133\n",
      "epoch: 15 / 20, batch: 1200 / 2440, batch_loss: 0.0345473550260067\n",
      "epoch: 15 / 20, batch: 1300 / 2440, batch_loss: 0.041066255420446396\n",
      "epoch: 15 / 20, batch: 1400 / 2440, batch_loss: 0.030847802758216858\n",
      "epoch: 15 / 20, batch: 1500 / 2440, batch_loss: 0.03058014065027237\n",
      "epoch: 15 / 20, batch: 1600 / 2440, batch_loss: 0.02700755000114441\n",
      "epoch: 15 / 20, batch: 1700 / 2440, batch_loss: 0.031747929751873016\n",
      "epoch: 15 / 20, batch: 1800 / 2440, batch_loss: 0.027990953996777534\n",
      "epoch: 15 / 20, batch: 1900 / 2440, batch_loss: 0.04009086638689041\n",
      "epoch: 15 / 20, batch: 2000 / 2440, batch_loss: 0.035466909408569336\n",
      "epoch: 15 / 20, batch: 2100 / 2440, batch_loss: 0.018631882965564728\n",
      "epoch: 15 / 20, batch: 2200 / 2440, batch_loss: 0.03113594278693199\n",
      "epoch: 15 / 20, batch: 2300 / 2440, batch_loss: 0.03081318736076355\n",
      "epoch: 15 / 20, batch: 2400 / 2440, batch_loss: 0.03630124777555466\n",
      "Epoch took 15 / 20,  took 3.7581088542938232 seconds\n",
      "epoch 15 / 20, validation_loss: 0.03237\n",
      "epoch: 16 / 20, batch: 0 / 2440, batch_loss: 0.034413982182741165\n",
      "epoch: 16 / 20, batch: 100 / 2440, batch_loss: 0.03066079318523407\n",
      "epoch: 16 / 20, batch: 200 / 2440, batch_loss: 0.02420884743332863\n",
      "epoch: 16 / 20, batch: 300 / 2440, batch_loss: 0.03201422840356827\n",
      "epoch: 16 / 20, batch: 400 / 2440, batch_loss: 0.030901558697223663\n",
      "epoch: 16 / 20, batch: 500 / 2440, batch_loss: 0.04468545317649841\n",
      "epoch: 16 / 20, batch: 600 / 2440, batch_loss: 0.02931942604482174\n",
      "epoch: 16 / 20, batch: 700 / 2440, batch_loss: 0.02658260613679886\n",
      "epoch: 16 / 20, batch: 800 / 2440, batch_loss: 0.02823583036661148\n",
      "epoch: 16 / 20, batch: 900 / 2440, batch_loss: 0.023703118786215782\n",
      "epoch: 16 / 20, batch: 1000 / 2440, batch_loss: 0.03214241936802864\n",
      "epoch: 16 / 20, batch: 1100 / 2440, batch_loss: 0.05156975984573364\n",
      "epoch: 16 / 20, batch: 1200 / 2440, batch_loss: 0.022807100787758827\n",
      "epoch: 16 / 20, batch: 1300 / 2440, batch_loss: 0.047288283705711365\n",
      "epoch: 16 / 20, batch: 1400 / 2440, batch_loss: 0.028143642470240593\n",
      "epoch: 16 / 20, batch: 1500 / 2440, batch_loss: 0.04788147658109665\n",
      "epoch: 16 / 20, batch: 1600 / 2440, batch_loss: 0.025280442088842392\n",
      "epoch: 16 / 20, batch: 1700 / 2440, batch_loss: 0.019540870562195778\n",
      "epoch: 16 / 20, batch: 1800 / 2440, batch_loss: 0.02438924089074135\n",
      "epoch: 16 / 20, batch: 1900 / 2440, batch_loss: 0.028488364070653915\n",
      "epoch: 16 / 20, batch: 2000 / 2440, batch_loss: 0.03383028134703636\n",
      "epoch: 16 / 20, batch: 2100 / 2440, batch_loss: 0.03738880902528763\n",
      "epoch: 16 / 20, batch: 2200 / 2440, batch_loss: 0.015757083892822266\n",
      "epoch: 16 / 20, batch: 2300 / 2440, batch_loss: 0.018174875527620316\n",
      "epoch: 16 / 20, batch: 2400 / 2440, batch_loss: 0.03431476652622223\n",
      "Epoch took 16 / 20,  took 3.7295949459075928 seconds\n",
      "epoch 16 / 20, validation_loss: 0.03105\n",
      "epoch: 17 / 20, batch: 0 / 2440, batch_loss: 0.03706962242722511\n",
      "epoch: 17 / 20, batch: 100 / 2440, batch_loss: 0.01589660346508026\n",
      "epoch: 17 / 20, batch: 200 / 2440, batch_loss: 0.030706588178873062\n",
      "epoch: 17 / 20, batch: 300 / 2440, batch_loss: 0.026566457003355026\n",
      "epoch: 17 / 20, batch: 400 / 2440, batch_loss: 0.0503225103020668\n",
      "epoch: 17 / 20, batch: 500 / 2440, batch_loss: 0.019566327333450317\n",
      "epoch: 17 / 20, batch: 600 / 2440, batch_loss: 0.023256465792655945\n",
      "epoch: 17 / 20, batch: 700 / 2440, batch_loss: 0.02770373225212097\n",
      "epoch: 17 / 20, batch: 800 / 2440, batch_loss: 0.024565793573856354\n",
      "epoch: 17 / 20, batch: 900 / 2440, batch_loss: 0.030950156971812248\n",
      "epoch: 17 / 20, batch: 1000 / 2440, batch_loss: 0.02563222125172615\n",
      "epoch: 17 / 20, batch: 1100 / 2440, batch_loss: 0.022875800728797913\n",
      "epoch: 17 / 20, batch: 1200 / 2440, batch_loss: 0.02650560438632965\n",
      "epoch: 17 / 20, batch: 1300 / 2440, batch_loss: 0.034440141171216965\n",
      "epoch: 17 / 20, batch: 1400 / 2440, batch_loss: 0.0107997702434659\n",
      "epoch: 17 / 20, batch: 1500 / 2440, batch_loss: 0.02754710614681244\n",
      "epoch: 17 / 20, batch: 1600 / 2440, batch_loss: 0.04943695664405823\n",
      "epoch: 17 / 20, batch: 1700 / 2440, batch_loss: 0.021550139412283897\n",
      "epoch: 17 / 20, batch: 1800 / 2440, batch_loss: 0.03668507561087608\n",
      "epoch: 17 / 20, batch: 1900 / 2440, batch_loss: 0.027265135198831558\n",
      "epoch: 17 / 20, batch: 2000 / 2440, batch_loss: 0.014094723388552666\n",
      "epoch: 17 / 20, batch: 2100 / 2440, batch_loss: 0.027720702812075615\n",
      "epoch: 17 / 20, batch: 2200 / 2440, batch_loss: 0.036374080926179886\n",
      "epoch: 17 / 20, batch: 2300 / 2440, batch_loss: 0.016018817201256752\n",
      "epoch: 17 / 20, batch: 2400 / 2440, batch_loss: 0.027793480083346367\n",
      "Epoch took 17 / 20,  took 3.760645866394043 seconds\n",
      "epoch 17 / 20, validation_loss: 0.0319\n",
      "epoch: 18 / 20, batch: 0 / 2440, batch_loss: 0.0232686884701252\n",
      "epoch: 18 / 20, batch: 100 / 2440, batch_loss: 0.028107669204473495\n",
      "epoch: 18 / 20, batch: 200 / 2440, batch_loss: 0.03205304965376854\n",
      "epoch: 18 / 20, batch: 300 / 2440, batch_loss: 0.03089730441570282\n",
      "epoch: 18 / 20, batch: 400 / 2440, batch_loss: 0.03681700676679611\n",
      "epoch: 18 / 20, batch: 500 / 2440, batch_loss: 0.02964841015636921\n",
      "epoch: 18 / 20, batch: 600 / 2440, batch_loss: 0.02049633115530014\n",
      "epoch: 18 / 20, batch: 700 / 2440, batch_loss: 0.029551027342677116\n",
      "epoch: 18 / 20, batch: 800 / 2440, batch_loss: 0.012959256768226624\n",
      "epoch: 18 / 20, batch: 900 / 2440, batch_loss: 0.033139027655124664\n",
      "epoch: 18 / 20, batch: 1000 / 2440, batch_loss: 0.031766362488269806\n",
      "epoch: 18 / 20, batch: 1100 / 2440, batch_loss: 0.03502247482538223\n",
      "epoch: 18 / 20, batch: 1200 / 2440, batch_loss: 0.032909177243709564\n",
      "epoch: 18 / 20, batch: 1300 / 2440, batch_loss: 0.021809574216604233\n",
      "epoch: 18 / 20, batch: 1400 / 2440, batch_loss: 0.0160225760191679\n",
      "epoch: 18 / 20, batch: 1500 / 2440, batch_loss: 0.0329984687268734\n",
      "epoch: 18 / 20, batch: 1600 / 2440, batch_loss: 0.032634325325489044\n",
      "epoch: 18 / 20, batch: 1700 / 2440, batch_loss: 0.023431893438100815\n",
      "epoch: 18 / 20, batch: 1800 / 2440, batch_loss: 0.02625272423028946\n",
      "epoch: 18 / 20, batch: 1900 / 2440, batch_loss: 0.030048634856939316\n",
      "epoch: 18 / 20, batch: 2000 / 2440, batch_loss: 0.03744270280003548\n",
      "epoch: 18 / 20, batch: 2100 / 2440, batch_loss: 0.024950725957751274\n",
      "epoch: 18 / 20, batch: 2200 / 2440, batch_loss: 0.01986616477370262\n",
      "epoch: 18 / 20, batch: 2300 / 2440, batch_loss: 0.019226348027586937\n",
      "epoch: 18 / 20, batch: 2400 / 2440, batch_loss: 0.022732824087142944\n",
      "Epoch took 18 / 20,  took 3.7623002529144287 seconds\n",
      "epoch 18 / 20, validation_loss: 0.03208\n",
      "epoch: 19 / 20, batch: 0 / 2440, batch_loss: 0.03312704339623451\n",
      "epoch: 19 / 20, batch: 100 / 2440, batch_loss: 0.027692651376128197\n",
      "epoch: 19 / 20, batch: 200 / 2440, batch_loss: 0.02391444519162178\n",
      "epoch: 19 / 20, batch: 300 / 2440, batch_loss: 0.013308357447385788\n",
      "epoch: 19 / 20, batch: 400 / 2440, batch_loss: 0.036430880427360535\n",
      "epoch: 19 / 20, batch: 500 / 2440, batch_loss: 0.03417734056711197\n",
      "epoch: 19 / 20, batch: 600 / 2440, batch_loss: 0.022724254056811333\n",
      "epoch: 19 / 20, batch: 700 / 2440, batch_loss: 0.015281498432159424\n",
      "epoch: 19 / 20, batch: 800 / 2440, batch_loss: 0.01947794295847416\n",
      "epoch: 19 / 20, batch: 900 / 2440, batch_loss: 0.03553059324622154\n",
      "epoch: 19 / 20, batch: 1000 / 2440, batch_loss: 0.031144220381975174\n",
      "epoch: 19 / 20, batch: 1100 / 2440, batch_loss: 0.031925320625305176\n",
      "epoch: 19 / 20, batch: 1200 / 2440, batch_loss: 0.03935621678829193\n",
      "epoch: 19 / 20, batch: 1300 / 2440, batch_loss: 0.024541771039366722\n",
      "epoch: 19 / 20, batch: 1400 / 2440, batch_loss: 0.025223394855856895\n",
      "epoch: 19 / 20, batch: 1500 / 2440, batch_loss: 0.02361888997256756\n",
      "epoch: 19 / 20, batch: 1600 / 2440, batch_loss: 0.024057429283857346\n",
      "epoch: 19 / 20, batch: 1700 / 2440, batch_loss: 0.019100001081824303\n",
      "epoch: 19 / 20, batch: 1800 / 2440, batch_loss: 0.018965888768434525\n",
      "epoch: 19 / 20, batch: 1900 / 2440, batch_loss: 0.025223977863788605\n",
      "epoch: 19 / 20, batch: 2000 / 2440, batch_loss: 0.02716778591275215\n",
      "epoch: 19 / 20, batch: 2100 / 2440, batch_loss: 0.027809496968984604\n",
      "epoch: 19 / 20, batch: 2200 / 2440, batch_loss: 0.02907310053706169\n",
      "epoch: 19 / 20, batch: 2300 / 2440, batch_loss: 0.036710359156131744\n",
      "epoch: 19 / 20, batch: 2400 / 2440, batch_loss: 0.025797389447689056\n",
      "Epoch took 19 / 20,  took 3.7802512645721436 seconds\n",
      "epoch 19 / 20, validation_loss: 0.03005\n",
      "Saving training history\n",
      "Saving training history to: data/trained_model/torch//angle_lan_test_run_notebook_torch_training_history.csv\n",
      "Saving model state dict\n",
      "Saving model parameters to: data/trained_model/torch//angle_lan_test_run_notebook_train_state_dict_torch.pt\n",
      "Saving training config to: data/trained_model/torch//angle_lan_test_run_notebook_train_config.pickle\n",
      "Saving training data details to: data/trained_model/torch//angle_lan_test_run_notebook_data_details.pickle\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4bb27eb383b94a6dba1da71255555cd1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded (0.000 MB deduped)\\r'), FloatProgress(value=1.0, max…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style>\n",
       "    table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
       "    .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
       "    .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
       "    </style>\n",
       "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>loss</td><td>█▅▃▃▃▂▃▂▂▃▂▄▂▂▂▂▁▂▂▂</td></tr><tr><td>val_loss</td><td>█▅▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>loss</td><td>0.02958</td></tr><tr><td>val_loss</td><td>0.03005</td></tr></table><br/></div></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">wd_0.0_optim_adam_test_run_notebook</strong> at: <a href='https://wandb.ai/afengler/test_run_notebook/runs/a6ab4cv5' target=\"_blank\">https://wandb.ai/afengler/test_run_notebook/runs/a6ab4cv5</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20231007_160318-a6ab4cv5/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "wandb uploaded\n",
      "Training finished successfully...\n"
     ]
    }
   ],
   "source": [
    "model_trainer.train_and_evaluate(\n",
    "    output_folder=\"data/trained_model/torch/\",\n",
    "    output_file_id=MODEL,\n",
    "    run_id=\"test_run_notebook\",\n",
    "    wandb_on=True,\n",
    "    wandb_project_id=\"test_run_notebook\",\n",
    "    save_data_details=True,\n",
    "    verbose=1,\n",
    "    save_all=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tanh\n",
      "tanh\n",
      "linear\n"
     ]
    }
   ],
   "source": [
    "# Load Network:\n",
    "infer_net = lanfactory.trainers.LoadTorchMLPInfer(\n",
    "    model_file_path=\"data/trained_model/torch/angle_lan_test_run_notebook_train_state_dict_torch.pt\",\n",
    "    network_config=\"data/torch_models/angle/angle_torch_network_config.pickle\",\n",
    "    input_dim=model_config[\"n_params\"] + 2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test parameters:\n",
    "v, a, z, t, theta = 0.5, 1.5, 0.5, 0.3, 0.3\n",
    "\n",
    "# Comparison simulator run\n",
    "sim_out = ssms.basic_simulators.simulator.simulator(\n",
    "    model=MODEL, theta=[v, a, z, t, theta], n_samples=50000\n",
    ")\n",
    "\n",
    "# Make input matric\n",
    "input_mat = torch.zeros(2000, 7)\n",
    "input_mat[:, 0] = torch.ones(2000) * v\n",
    "input_mat[:, 1] = torch.ones(2000) * a\n",
    "input_mat[:, 2] = torch.ones(2000) * z\n",
    "input_mat[:, 3] = torch.ones(2000) * t\n",
    "input_mat[:, 4] = torch.ones(2000) * theta\n",
    "input_mat[:, 5] = torch.tensor(\n",
    "    np.concatenate(\n",
    "        [\n",
    "            np.linspace(5, 0, 1000).astype(np.float32),\n",
    "            np.linspace(0, 5, 1000).astype(np.float32),\n",
    "        ]\n",
    "    )\n",
    ")\n",
    "input_mat[:, 6] = torch.tensor(\n",
    "    np.concatenate([np.repeat(-1.0, 1000), np.repeat(1.0, 1000)]).astype(np.float32)\n",
    ")\n",
    "\n",
    "net_out = infer_net(input_mat.cuda())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([6.06134177e-04, 9.69814683e-04, 3.27312455e-03, 4.72784658e-03,\n",
       "        1.07891883e-02, 1.95175205e-02, 2.54576354e-02, 3.51557822e-02,\n",
       "        4.15808045e-02, 5.27336734e-02, 6.76445741e-02, 7.46757306e-02,\n",
       "        8.82531361e-02, 1.00860727e-01, 1.14801813e-01, 1.20378247e-01,\n",
       "        1.32137251e-01, 1.46320790e-01, 1.48139193e-01, 1.34440560e-01,\n",
       "        9.64965609e-02, 2.44878207e-02, 1.21226835e-04, 0.00000000e+00,\n",
       "        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 4.15808045e-02,\n",
       "        2.90217044e-01, 5.02000325e-01, 5.30124951e-01, 5.14486689e-01,\n",
       "        4.83937527e-01, 4.26476007e-01, 3.70105528e-01, 3.11916647e-01,\n",
       "        2.64153274e-01, 2.16996035e-01, 1.75900138e-01, 1.41714170e-01,\n",
       "        1.19408433e-01, 7.79488551e-02, 5.66129321e-02, 3.97624020e-02,\n",
       "        2.65486769e-02, 1.39410861e-02, 7.87974430e-03, 3.75803190e-03,\n",
       "        1.69717569e-03, 6.06134177e-04]),\n",
       " array([-4.07246923, -3.90748926, -3.74250929, -3.57752932, -3.41254934,\n",
       "        -3.24756937, -3.0825894 , -2.91760942, -2.75262945, -2.58764948,\n",
       "        -2.42266951, -2.25768953, -2.09270956, -1.92772959, -1.76274961,\n",
       "        -1.59776964, -1.43278967, -1.2678097 , -1.10282972, -0.93784975,\n",
       "        -0.77286978, -0.6078898 , -0.44290983, -0.27792986, -0.11294989,\n",
       "         0.05203009,  0.21701006,  0.38199003,  0.54697001,  0.71194998,\n",
       "         0.87692995,  1.04190992,  1.2068899 ,  1.37186987,  1.53684984,\n",
       "         1.70182981,  1.86680979,  2.03178976,  2.19676973,  2.36174971,\n",
       "         2.52672968,  2.69170965,  2.85668962,  3.0216696 ,  3.18664957,\n",
       "         3.35162954,  3.51660952,  3.68158949,  3.84656946,  4.01154943,\n",
       "         4.17652941]),\n",
       " [<matplotlib.patches.Polygon at 0x7f7ad4ec5390>])"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAABOYUlEQVR4nO3deXyU5b3//9fMJJlkspNAFhIIO7IGQRDrgpaKrdZqF2lPFcupWBX82pOe1tKeSndsteoppUVpqR6tP+mC2taKYhTUioIgq0DYQkJC9n2bSTLz+2MmGySQhCT3zNzv5+MxD5OZe2Y+RO7hnev63Ndl8Xg8HkREREQMYjW6ABERETE3hRERERExlMKIiIiIGEphRERERAylMCIiIiKGUhgRERERQymMiIiIiKEURkRERMRQIUYX0Btut5vCwkKio6OxWCxGlyMiIiK94PF4qK2tJTU1Fau15/GPgAgjhYWFpKenG12GiIiI9EN+fj5paWk9Ph4QYSQ6Ohrw/mFiYmIMrkZERER6o6amhvT09PZ/x3sSEGGkbWomJiZGYURERCTAXKjFQg2sIiIiYiiFERERETGUwoiIiIgYSmFEREREDKUwIiIiIoZSGBERERFDKYyIiIiIoRRGRERExFAKIyIiImIohRERERExlMKIiIiIGEphRERERAylMCIiIiKGCohde0VEDFOVDw3lPT/uSIC49KGrRyQIKYyIiPSkKh/WzoXmhp6PCXXA8h0KJCIXQWFERKQnDeXeIPL59ZA48dzHy3Jg0zLvcQojIv2mMCIiciGJEyE10+gqRIKWGlhFRETEUAojIiIiYiiFERERETGUwoiIiIgYSmFEREREDKWraUREesHj8fDByQo+zK3AERbCwkuSGGV0USJBQmFEROQC6lwtrHh6J1uPlLbf97N/HeKHc5q5w8C6RIKFwoiIyAX84KWDbC1KJCzEyqenJVNS42T7iXJe2JnPHXbwlB7B0t0TtVS8SK8ojIiIXEBOcS1xjhSe+/o8po2MBeCvu07zxF/LaPDYcbx4d/dP1FLxIr2iMCIi0oP8ygbaYsRjt81sDyIAX5ydRlXDVSx85RGSQ+v53e2XkhQd3vFkLRUv0mu6mkZEpAd/ej8PgMvHDuO6yUnnPP71K8eQPmYiu5tH88Odod4l49tu3e1lIyLdUhgREelGUXUT750oB+D2y0d3e4zFYuFHn5uKzWrh1QNF7DtdNYQVigQPhRERkW68sDOPVrcHgDEJkT0eNzk5hs/NTAXgt28dH5LaRIJNv8LI2rVrycjIIDw8nHnz5rFjx44ej3366aexWCxdbuHh4T0eLyLiD/6xt7DXx96zYBwAmw8WcbKsfrBKEglafQ4jGzduJCsri1WrVrF7925mzpzJokWLKCkp6fE5MTExnDlzpv126tSpiypaRGQwHSup5XhpPaHW3n1ETkyK5tpJwwHviIqI9E2fw8hjjz3GsmXLWLp0KVOmTGHdunU4HA42bNjQ43MsFgvJycntt6SkcxvBREQMU5UPhXvab7ve38pUy0k+nVzd65f48lzveqx/23UaV4t7cOoUCVJ9urTX5XKxa9cuVq5c2X6f1Wpl4cKFbN++vcfn1dXVMXr0aNxuN5deeik///nPmTp1ao/HO51OnE5n+/c1NTV9KVNEpPeq8mHtXGhuaL9rMbDYDpTjXSvEkXDBl7lu8ggSo+yU1TnJPlTMpy/8FBHx6dPISFlZGa2treeMbCQlJVFUVNTtcyZNmsSGDRt4+eWXee6553C73VxxxRWcPn26x/dZvXo1sbGx7bf0dF2jLyKDpKHcG0Q+vx7u3obr629xa8tqbnT+jPwvvdrrRctCbVa+ODsNgJf39L7fRESG4Gqa+fPns2TJEjIzM7nmmmvYtGkTw4cP58knn+zxOStXrqS6urr9lp+fP9hliojZJU6E1Ew+ah7NRy2jKY6cTNqU+X1asOyzM1MAeOtI
CY3NrYNVqUjQ6dM0TWJiIjabjeLi4i73FxcXk5yc3KvXCA0NZdasWRw7dqzHY+x2O3a7vS+liYgMiA9PVQIwb+wwLJZud5zp0ZSUGDISHOSWN7Ajt4JrBqNAkSDUp5GRsLAwZs+eTXZ2dvt9breb7Oxs5s+f36vXaG1tZf/+/aSkpPStUhGRIdC2cNms9Lg+P9disfCZ6d7PtnePlg1gVSLBrc/TNFlZWaxfv55nnnmGQ4cOce+991JfX8/SpUsBWLJkSZcG1x//+Me8/vrrnDhxgt27d3P77bdz6tQp7rrrroH7U4iIDJADBd6G+c770PRFWxjZ7RthEZEL6/NGeYsXL6a0tJSHHnqIoqIiMjMz2bx5c3tTa15eHtZO1+ZXVlaybNkyioqKiI+PZ/bs2bz33ntMmTJl4P4UIiIDoLzOSUFVIwBTU2P69RpTU2NIirHTVOsG20BWJxK8+rVr74oVK1ixYkW3j23durXL948//jiPP/54f95GRGRIHSj0joqMTYwkOjy0X69hsVhYMHEEB3YdHsjSRIKa9qYREfE5WlwLwOSU6It6nWsnDx+IckRMQ2FERMTneGkdAOOHR13U63xifCIhVu+VOIW+aR8R6ZnCiIiIz7ESbxgZN+Liwkh0eChTUrw9Jx+qkVXkgvrVMyIiEoyOl3p33B13kSMjAJmj4qAMSk/ug8LU7g9yJPRpUTWRYKUwIiICVDc1U1HvAmDs8MiLfr0p48fQsMvO10tWw1Oruz8o1NHr5eZFgpnCiIgIUFDp7e0YGReBI+ziPxonT5rCpz2PYXdV8euvZDIu8azRlrIc2LTMuzeOwoiYnMKIiAhwutK7a+9AjIqAd+O8tDET2XqklLeqUxk3Y+yAvK5IMFIDq4gIUFjdBEBGwsCEEYD5YxMAeP9ExYC9pkgwUhgREQFKa50ApMVHDNhrzh/nDSMfnCyn1e0ZsNcVCTYKIyIiQEmNN4yMHMAwMjU1lmh7CLVNLXzsW91VRM6lMCIiApTUeqdpRsYNXBixWS3MyYgH4MNTmqoR6YnCiIgItF/WO5AjIwBzMoYBsEuLn4n0SGFERARweyAsxEpipH1AX/fSUd6REYURkZ4pjIiI+KTGhmP17SkzUGamx2KzWjhT3aR9akR6oDAiIuIz0FM0AI6wkPZ9ajQ6ItI9hREREZ+BbF7tbPZoTdWInI/CiIiIT6rCiIghFEZERHySYsIH5XXbwsjHZ2pocLUMynuIBDKFERERnxHRA3slTZvUuAhSYsNpdXvYk181KO8hEsgURkREfIYPUhgBmDUqDoC9+dWD9h4igUphRERMze3p2DNmMMPIzLQ4APadrhq09xAJVAojImJqNU0dPRwJA7zgWWczfGFkr6ZpRM6hMCIiplbZ4F0GPjYilLCQwftInJ4Wi8UChdVN7fvgiIiXwoiImFpbGIl3hA7q+0TZQxg/PAqAfeobEelCYURETK2yrhmAeEfYoL/XzPQ4QH0jImdTGBERU6tsbBsZGYIwkhYLwJ7TGhkR6UxhRERMrbLeNzISObjTNNB1ZMSD5/wHi5iIwoiImFpFw9CNjExOjiHMZqWqoZmiajWxirRRGBERU6uq94aRuCEII2EhVi5J9e7gm1NSN+jvJxIoFEZExNRqm7zTNLERIUPyfm19IzlFCiMibRRGRMTU2hY9iw4f/J4R6FiJ9WhJ7ZC8n0ggUBgREdPyeDzUNHrDSMxQjYyke0dGjmmaRqSdwoiImFaDq5VmtxuAmCEaGRmbGEVkmA1ni3tI3k8kECiMiIhpVfiaVwHCQ4fm49BqtTDF18QqIl4KIyJiWlUNze1fW7AM2ftOGxk7ZO8lEggURkTEtNrWGBlq01IVRkQ6UxgREdOqMiiMTE/rCCNuj1ZiFVEYERHT6twzMpTGJkZiD/F+/BZoJVYRhRERMa/KTj0jQynEZmVsYiQAx3WJr4jCiIiYV6VBIyMA40ZEAQojIqAwIiImVmlQzwjAuOHe
MKKVWEVgaJYcFBHxQ0aGkfG+kRFKc3AXfITVctalxY4EiEsf+sJEDKAwIiKmVVnfPISri3SVnpZGg8fOw5Y1sH7NuQeEOmD5DgUSMQWFERExrerGZuIMeu/QYaO5Z9g6iooKefCGSVw9YXjHg2U5sGkZNJQrjIgpKIyIiGnVGBhGAJJGTSD7jJ33GtK5OnWygZWIGEsNrCJiSq1uD7XOFkNrmO5bFv5AQbWhdYgYTWFEREyptsmYNUY6a1sW/kBhNR6txCompjAiIqZU0+gdFQkPMe5jcGJyFKE2C1UNzRRUNRpWh4jRFEZExJRqfCMjkXbjWufsITYmJkUDmqoRc1MYERFTqmk0PoxAp6maghpD6xAxksKIiJiSP4yMAExL6+gbETErhRERMaW2npEoo8NIagzgnaZRE6uYlcKIiJhS28hIlN1maB2XpMRgs1ooq3NRXOM0tBYRoyiMiIgp+UvPSHiojQm+fWr2q4lVTEphRERMqbotjIQZvxD1VF8Tq8KImJXCiIiYUk2Tt2ckMtzYaRqA6SO9fSMHFUbEpPoVRtauXUtGRgbh4eHMmzePHTt29Op5L7zwAhaLhVtuuaU/bysiMmDapmmi/GBkZHqaRkbE3PocRjZu3EhWVharVq1i9+7dzJw5k0WLFlFSUnLe5+Xm5vLf//3fXHXVVf0uVkRkoPjLpb3gbWK1WqCk1klJTZPR5YgMuT6Hkccee4xly5axdOlSpkyZwrp163A4HGzYsKHH57S2tvLVr36VH/3oR4wdO/aiChYRGQhtl/b6QxhxhIUwbriaWMW8+hRGXC4Xu3btYuHChR0vYLWycOFCtm/f3uPzfvzjHzNixAi+/vWv9+p9nE4nNTU1XW4iIgPJXy7tbdOxg68+78R8+hRGysrKaG1tJSkpqcv9SUlJFBUVdfucd999lz/84Q+sX7++1++zevVqYmNj22/p6el9KVNE5IJq/OhqGoBpI9U3IuY1qFfT1NbWcscdd7B+/XoSExN7/byVK1dSXV3dfsvPzx/EKkXEbFrdHupdrQA4/GCaBjqaWLVhnphRn87CxMREbDYbxcXFXe4vLi4mOTn5nOOPHz9Obm4un/3sZ9vvc7vd3jcOCeHIkSOMGzfunOfZ7XbsdntfShMR6bU6Z0v71xGh/jFNMyUlBosFimqaqGxoJt7ogkSGUJ9GRsLCwpg9ezbZ2dnt97ndbrKzs5k/f/45x0+ePJn9+/ezZ8+e9tvNN9/Mtddey549ezT9IiKGqPeFkVCbhbAQi8HVeEXaQxibGAnAsZI6g6sRGVp9Hp/MysrizjvvZM6cOcydO5cnnniC+vp6li5dCsCSJUsYOXIkq1evJjw8nGnTpnV5flxcHMA594uIDJW2MBJpD8GCf4QR8DaxHi+t53hpHZcZXYzIEOpzGFm8eDGlpaU89NBDFBUVkZmZyebNm9ubWvPy8rBatbCriPivWqd/7Nh7tmkjY3lpTyHHSmqNLkVkSPXrTFyxYgUrVqzo9rGtW7ee97lPP/10f95SRGTA1PtxGAFN04j5aAhDREyn8zSNP5ma6t2jprTOZXAlIkNLYURETKe2yT9HRqLDQ9ubWEXMRGFEREzHX6dpoGOqRsRMFEZExHTq2qdp/GONkc6mjYwxugSRIacwIiKmU+f0rr4aZQ81uJJzaWREzEhhRERMp2Oaxh9HRjrCSNtmfiLBTmFEREynbZomKtz/ekZiwkNJjQ0HdImvmIfCiIiYTp2fXtrbZtyIKACOlSqMiDkojIiI6fjz1TQA431h5LhGRsQkFEZExHTq/D2MDNfIiJiLwoiImI7fT9P4wkhRtZOqBq3GKsFPYURETKfOT1dgbRPdqbH2QEGNgZWIDA3/PBNFRAZRl54RPx54GG8poOjI+xCZ1vUBRwLEpRtTlMggUBgREVNxuz3Uu7yLnkX6axhxJNBsDed/w34LH/4WPjzr8VAHLN+hQCJBQ2FEREyl3tXS/nV0eAjUGlhM
T+LS+ejm1/nRxndIiQ3n90vmdDxWlgOblkFDucKIBA2FERExlXrfUvA2qwV7iP+2zU2YcAkHPac5WAXV8VOJjfC/petFBor/nokiIoOg/UqaMBsWi8XganoWHxlGWnwEAAcLqg2uRmRwKYyIiKm0hZHocP8faZiW6t2n5kChwogEN4URETGV+vY1Rvxvk7yzTU/zhpG9pxVGJLgpjIiIqfj7gmedzUyLA2Df6SpD6xAZbAojImIq/r4vTWdtIyP5FY2U1zkNrkZk8CiMiIipNPjWGIkI9f9pmtiIUMYOjwRgn6ZqJIgpjIiIqTT6wogjzP/DCECmb6pmT36VoXWIDCaFERExlbaREUcATNMAzEyPA2Cv+kYkiCmMiIipNPhWYHUEwDQNdAoj+VV4PB5jixEZJAojImIqDQE2TXNJSjShNguVDc3kVzQaXY7IoFAYERFTaW9gDQuMaRp7iI0pKTEA7NFUjQQphRERMZXG5sBZ9KxN56kakWCkMCIiptK2UV4gXNrbRoufSbBTGBERU+m4tDcwpmmgY2Rkf0E1LW41sUrwURgREVNp8E3TBEoDK8DYxEii7SE0NbvJq2gwuhyRAacwIiKmEmhX0wBYrRZmpHuXhs8prjW4GpGBpzAiIqbS4Ay8aRro6BtRGJFgpDAiIqbStuhZRACNjEBH30hOkcKIBB+FERExlcbmwJumAcj0hZFT6hmRIKQwIiKm4Wpx09zqvRolMsCmaZJiwkmOCUcX00gwUhgREdNou6wXAm+aBmCmr4lVJNgojIiIabRd1htitRAWEngff7NGxRtdgsigCLyzUUSknzr2pQm8URGA2aM7wogHzddI8FAYERHTaJumCbR+kTbTR8YSYrUAUFzjNLgakYGjMCIiphGIC551Fh5qY+zwKAAO6xJfCSIKIyJiGvUBusZIZ5NTogE4fKbG4EpEBo7CiIiYRmOAj4wAXJLsCyMaGZEgojAiIqbREIA79p7tkpQYAE6U1bevJisS6BRGRMQ0Gl2Bt2Pv2YZH2QFodXvYd7ra4GpEBobCiIiYRn2AX9p7tl2nKo0uQWRAKIyIiGkE+tU0Z/soT2FEgoPCiIiYRts0TaCuM3K23XlVeDxa/EwCn8KIiJhGoK/A2lmozUpFvYvccu3iK4FPYURETCOYpmnGj/AufrZbfSMSBIJjrFJEpBca2hc9C/yPvqviyjl0poYzhxshZXzXBx0JEJduTGEi/RD4Z6SISC81NbsBiAgN4JERRwKEOvjciR/yOTtw1HfrLNQBy3cokEjAUBgREdNwtninaewhATxDHZcOy3dQUXqGOzbswGqBF+6+vKMptywHNi2DhnKFEQkYCiMiYhptIyPhgTwyAhCXzrC4dGrj68mraGCncxQLMkYYXZVIvwXwrwciIn3jbPGGkYAeGenksoxhAOzMrTC4EpGL068zcu3atWRkZBAeHs68efPYsWNHj8du2rSJOXPmEBcXR2RkJJmZmTz77LP9LlhEpL+czd5pmoAfGfGZOyYegJ0ndUWNBLY+h5GNGzeSlZXFqlWr2L17NzNnzmTRokWUlJR0e/ywYcP4/ve/z/bt29m3bx9Lly5l6dKlvPbaaxddvIhIXwTryMie01Xt/TAigajPZ+Rjjz3GsmXLWLp0KVOmTGHdunU4HA42bNjQ7fELFizg1ltv5ZJLLmHcuHE88MADzJgxg3ffffeiixcR6YumIBsZGZMYSWJUGK4WtzbNk4DWpzDicrnYtWsXCxcu7HgBq5WFCxeyffv2Cz7f4/GQnZ3NkSNHuPrqq3s8zul0UlNT0+UmInKxgm1kxGKxtI+O7DipvhEJXH06I8vKymhtbSUpKanL/UlJSRQVFfX4vOrqaqKioggLC+PGG29kzZo1fOpTn+rx+NWrVxMbG9t+S0/X5WkicvGCbWQE1MQqwWFIfj2Ijo5mz5497Ny5k5/97GdkZWWxdevWHo9fuXIl1dXV7bf8/PyhKFNEglhLq5sW
t3dTuWAZGQGYO8YbRnblVtLq1qZ5Epj6tM5IYmIiNpuN4uLiLvcXFxeTnJzc4/OsVivjx3uXK87MzOTQoUOsXr2aBQsWdHu83W7Hbrf3pTQRkfNqm6IBsIcGTxi5JCWGKHsItc4WDhfVMNXogkT6oU9nZFhYGLNnzyY7O7v9PrfbTXZ2NvPnz+/167jdbpxOZ1/eWkTkonQJIyHBM01js1q4dHTbJb6aqpHA1OdfD7Kysli/fj3PPPMMhw4d4t5776W+vp6lS5cCsGTJElauXNl+/OrVq9myZQsnTpzg0KFD/OpXv+LZZ5/l9ttvH7g/hYjIBbT1i4TaLNisFoOrGVhzM7xhZIf6RiRA9Xk5+MWLF1NaWspDDz1EUVERmZmZbN68ub2pNS8vD6u1I+PU19dz3333cfr0aSIiIpg8eTLPPfccixcvHrg/hYjIBbSNjIQH0ahIm44rairxXJtAcEUtMYN+7U2zYsUKVqxY0e1jZzem/vSnP+WnP/1pf95GRGRgVOXDmTymWk4SFxIKhXs6HivLMaysgTIzPY4wm5WyOieFVU2MNLogkT7SRnkiEtyq8mHtXMY0N/CKHWgFnjrrmFAHOBIMKG5ghIfamJkey87cSg4W1iiMSMBRGBGR4NZQDs0NHLvyMR7IbiQtPoInb5/d9RhHAsQF9npGl2UMY2duJQcKqrne6GJE+khhRERMoTpqLAc9DXjsMZCaaXQ5A27e2AR+u/U4+wu0LLwEnuC52F5E5DxcbUvBB9EaI53NGR1PiNVCSa2WTZDAE5xnpYjIWVytwXs1DUCkPYQZabFGlyHSLwojImIKwT4yAjB/XOA24Yq5Be9ZKSLSiSuI1xlpM39sYvvXHrRPjQQOhRERMQUzjIzM9vWNAJypajK4GpHeC96zUkSkk2DvGQGICLMxOTkagH26qkYCiMKIiJhCswlGRgCmp8UBsO+0wogEjuA+K0VEfJy+kRF7SHB/7M0Y6b2iZl9BNR6P+kYkMAT3WSki4tM2MhIeGrzTNACTfNM0lfUujpfWG1yNSO8ojIiIKbTt2hvsIyOd/3zbT5QbWIlI7wX3WSki4tPcao6Rkc7eP64wIoFBYURETMFlkpGRzt4/Ua6+EQkI5jkrRcTUOtYZMcfIiD3ESnm9i5ziOqNLEbkg7dorIqbgMsnVNG0WDq/CecbNkY8sTJqZ2vGAIwHi0o0rTKQbCiMiYgouk1xNgyMBQh3cV/EL7rMDH/hubUIdsHyHAon4FYURETEF04yMxKXD8h0cPpHLt/6ylyi7jeeXXY7NYoGyHNi0DBrKFUbEryiMiIgpmGZkBCAunfEzR5L39wZqm1rY7x5DZnqc0VWJ9CjIf0UQEfEyzciIT4jNyhXjEgB492ipwdWInJ85zkoRMT1ni/cSV1OMjPhcOT4RgHePlRlcicj5KYyIiCk0t7YC5hkZAbhywnAAdp2qpMHVYnA1Ij0zz1kpIqbmajbfyEhGgoORcRE0t3rYcbLC6HJEeqQwIiKm0Ow2V88IgMVi6ZiqOaqpGvFf5jkrRUQwzwqsbT4xQX0j4v8URkTEVMJNNDIC8AnfFTWHi2qpbGg2uBqR7pnrrBQRU7NZLYTYzPWxlxBlZ0pKDAB7T1cZW4xID8x1VoqIqZltVKTNVb6pmo/yKg2uRKR75jwzRcSUzNYv0uYTvibWPflVxhYi0gOFERExDbOOjFyWMYwwm5WyOpfRpYh0y5xnpoiYkllHRiLCbMzJiDe6DJEeKYyIiGmYaY2Rs7VN1Yj4I/OemSJiOmYdGYGOJlaAFrfHwEpEzqUwIiKmYdaeEYCpqbFEh4cAkFNca3A1Il2Z98wUEdMx88iIzWohMz0OgN15VYbWInI2hRERMQ0zj4wAXDra28S665TWGxH/Yu4zU0RMxcwjIwCXpnvDyNGSWirqdZmv+A+FERExDbOPjCRGhQHg8cA7R0sNrkakg7nPTBExFXuoPvLavJ2jXXzFf+jMFBHTsIeY
e5qms205pbh1ia/4CYURETGNcI2MAN6fQ1mdk0NFNUaXIgIojIiIiWhkxGvGyDhAUzXiPxRGRMQ0NDLiNdu3T822nBKDKxHx0pkpIqahkRGv2aPiAPgwt5I6Z4uxxYigMCIiJqKREa+U2AgyEhy0uD28d0xTNWI8nZkiYhoaGelw9cThALyt9UbED4QYXYCIyFDRyIhPWQ43JSayy3KSwkOFeOa2YMHifcyRAHHpxtYnpqMwIiKmYfqREUcChDpg0zLmAq/YASfwVKdjQh2wfIcCiQwphRERMQ3Tr8Aal+4NGg3lAHz/pf3sya/mG1eP5eaZqVCWA5uWeR9XGJEhZPIzU0TMxPQjI+ANGamZkJrJ6GlXcNAzhheLEr33JU40ujoxKYURETEN9Yx01dbEuv1EOU3NrQZXI2amM1NETEMjI11NSoomOSacpmY3H5ysMLocMTGFERExDY2MdGWxWLh2snd05K3DWo1VjKMzU0SCmoeOnWk1MnKuayeNAODNwyVdflYiQ6lfYWTt2rVkZGQQHh7OvHnz2LFjR4/Hrl+/nquuuor4+Hji4+NZuHDheY8XERlIrpZOYUQjI+f4xPhEwmxW8ioaOF3ZZHQ5YlJ9PjM3btxIVlYWq1atYvfu3cycOZNFixZRUtL9EN/WrVv5yle+wltvvcX27dtJT0/n+uuvp6Cg4KKLFxG5EFeru/3rcI2MnCPSHsK8scMA2JmrvhExRp/DyGOPPcayZctYunQpU6ZMYd26dTgcDjZs2NDt8X/605+47777yMzMZPLkyfz+97/H7XaTnZ190cWLiFxIsy+MWCwQarMYXI1/+uRk71TNjpPlBlciZtWnMOJyudi1axcLFy7seAGrlYULF7J9+/ZevUZDQwPNzc0MGzasx2OcTic1NTVdbiIi/eFq8YYRu82KxaIw0p3rJicBcOhMrcGViFn1KYyUlZXR2tpKUlJSl/uTkpIoKirq1Ws8+OCDpKamdgk0Z1u9ejWxsbHtt/R0rQQoIv3TNjISGqJ+kZ6MSnAwbngkLW41sIoxhvTsfPjhh3nhhRd48cUXCQ8P7/G4lStXUl1d3X7Lz88fwipFJJg4O42MSM+u803ViBihT3vTJCYmYrPZKC4u7nJ/cXExycnJ533uo48+ysMPP8wbb7zBjBkzznus3W7Hbrf3pTQRkW61TdNoZOT8rpucxHvver92ezxa90GGVJ/+voWFhTF79uwuzadtzajz58/v8Xm//OUv+clPfsLmzZuZM2dO/6sVEemj9p4RhZHzmpMRT6Tde7VRTnGdwdWI2fT57MzKymL9+vU888wzHDp0iHvvvZf6+nqWLl0KwJIlS1i5cmX78b/4xS/4wQ9+wIYNG8jIyKCoqIiioiLq6vSXXUQGn0s9I70SarNyaXo8AB/qEl8ZYn2apgFYvHgxpaWlPPTQQxQVFZGZmcnmzZvbm1rz8vKwWjtO+t/97ne4XC6++MUvdnmdVatW8cMf/vDiqhcRuQCXekZ67bKMYZAPO3IruN3oYsRU+hxGAFasWMGKFSu6fWzr1q1dvs/Nze3PW4iIDAiNjPTe7Ix4eAeOl9ZTXNNEUkzPFxqIDCSdnSIS1NpGRsI0MnJBcRGh7V9r4zwZSjo7RSSotYcRjYz0yRuHFEZk6OjsFJGg1jZNo5GRvnn3WCmNrlajyxCT0NkpIkGtbddejYz0XlKMnaZmN+8cLTW6FDEJnZ0iEtRcrd7f7tXA2nuXj00E4PWPiy9wpMjA0NkpIkGtbWREi5713uVjvRuZZh8qpsU3zSUymHR2ikhQa27xjYyoZ6TXpqTEEO8IpbKhmQ9PVRpdjpiAzk4RCWpOLQffZyFWC9dN9i5k+fpBTdXI4OvXomciIoFCV9P0Q1kOX0hN4PBHJ8k9UIhnthMLFnAkQFy60dVJEFIYEZGgpnVG+sCRAKEO2LSMK4BX7IATeMr3eKgDlu9QIJEBpzAiIkFNIyN9EJfuDRsN5QD85JWP
ef9EBV+dN5r/GNsIm5Z5H1MYkQGms1NEglqzRkb6Ji4dUjMhNZPJs67ioGcMf8qLg8SJRlcmQUxnp4gENY2M9N8nL0nCaoGDhTUU1zYZXY4EMZ2dIhLU1DPSf8Miw7gsw7vmyPvHKwyuRoKZzk6RAFRY1ciz23N54o0c/rX/TPs/uHIuhZGLs2hqMgD/Pl5mcCUSzNTAKhJAmlvdPLYlhz+8c7J9+gFgwogofn/nHEYnRBpYnX9SGLk4n56ezI//+TEfn6mBMKOrkWCls1MkQFQ3NnPnhh38butxXK1uLh0Vx21z0kiIDONoSR1feep9KupdRpfpd9QzcnFSYiOYPToej8foSiSYaWREJAA0uFpY+scd7M6rIjLMxqNfmskN05KxWCyU1Dax+Mn3OVlWz4N/28f6JXOMLtevuFrcYNPIyMW4cXoKf8szugoJZjo7RfxcS6ube57bze68KmLCQ/jzPfP59PQULBYLACOiw/ntVy8lxGphy8fFvHtUc/ttPB6PRkYGwKenJ7d/XVan0TcZeDo7RfzcI68d4e2cUhxhNp7+z7lMTY0955hLUmK4/fLR3uNfP4JHY+oANLd6cPt+FNoor/9SYiOYkhIDwHtqZJVBoLNTxI/9a/8Znnz7BACPfmkml46K7/HYFdeNJyzEyt78KnbnVQ1Rhf7N6duxFzRNc7GuHJ8IwDsaeZNBoLNTxE+dKK3j23/ZC8A3rh7LZ6annPf4xCg7t2SmAvDMe7mDXV5AcHa65DksxGJgJYHvE74w8vGZGoqqtQCaDCyFERE/1NLq5r/+vJd6VyuXjx3GtxdN6tXzvjrPO1Xz+sdF1DtbBrPEgNDU3DEyYkFh5GIkRnVc1/uv/WcMrESCkcKIiB9a+9Zx9uZ7G1YfX5xJSC/7HWakxTImMZKmZjevf1w0yFX6P6cWgxsUryiMyABTGBHxM3vzq/j1m0cB+Mkt00iJjej1cy0WC5/zTdW8vKdwUOoLJJ1HRmRgWCyw61Ql+RUNRpciQURhRMSPNDW38l8b99Dq9vDZmal8LnNkn1/jphne3pL3jpWbfqpGIyMDb8ZI79Vcf9+rsCsDR2FExI/89q1jnCirJynGzk8+N7VfrzFueBSjhjlwtbr59zFzX/mgkZGBt2DyCAA27T6tS8hlwCiMiPiJ3LJ61m3zXsb7o5unEufoxUYgVflQuKfLzXJmL18dVUkqZbx1pGQwS/Z7GhkZeFeMS8AeYuV4aT0HC2uMLkeChJaDF/EDHo+HVX8/iKvVzdUTh7fvlAp4A0dD+blPaiiDjXdA87lz998A7rDb+cqhX+PxTG9frdVsnBoZGXBRYSEsnJLEK/vO8OJHBUwbee4ifCJ9pTAi4gdeO1jEtpxSwmxWfnTz1I7wUJUPa+d2GzgACHXA7X8DR2KXu13Fh3G8/A1a6so4XFTLJb7VM81GIyOD49bMkbyy7wx/31vIyk9P7vXVXiI9URgRMViDq4Uf/+NjAL5xzVjGJEZ2erDcG0Q+vx4SJ577ZEcCxKWfc3fnCZ73T5SbNoyoZ2RwXD1xOPGOUEprnbx3vJyrJw43uiQJcAojIgZb8+YxCqubSIuP4L4F47s/KHEipGb26/W3Hy9n6SfG9L/AAKaRkUFQlkMYsGx8La/sP8PO7fVcHeULyj2EY5ELURgRMdCxkjp+/463aXXVZ6cSEWYb8Pf44GQFbrcHq9V8fSMaGRlAjgTvtOCmZQDcB9xnB04CT/mOCXXA8h0KJNJnCiMiBvE2rR6gudXDJyeP4FNTkgb8PSJCrVQ3NvPxmRpTNho6mzUyMmDi0r1Bw9dM7cHDsv/bxZnqJr51/USuS6jyBpWGcoUR6TN1HYkY5J/7zvDvY+XYQ6z88Ob+rSlyIVNTvQHk/RPdXI1jAk0tGhkZUHHp3unC1EwsqbOYPucaDnrG8NTR6O57mkR6SWFExAB1zhZ++oq3afW+BeNJH+YYlPeZkWbuMKKRkcH1xTlpWCzw/okK
zlQ3Gl2OBDBN04gY4H/fyKG4xsnoBAffuGbsoL3PnMhSplrqqT2Zh7sgBGvn9UZM0GyoBtbBNTIugqsmDOftnFJe/7iYO40uSAKWwojIEDtSVMuGf+cC8MObpxIeOvBNq23NhqO3fZNX7L771p91jAmaDdXAOvgWz0nn7ZxSsg+VKIxIvymMiAwhj8fDD14+QKvbw6KpSVw7aUTPK6wClOX07406NRv+z8sH+CivinuuGcdnfZvoUZZjimZDjYwMvoVTRhDvCKW83gX2Cx8v0h2FEZEh9NKeAnacrCA81MoPbppy4RVWwTuC4Ujo+5vFpUNcOsmTIjl4KofN5Ul8tp9rlQQqjYwMPnuIjVtnpfHBe/0MziIojIgMmZqmZn72ymEA7r9uAmnxDijMOf8Kq3DRvR1zx3iDzAcnK/B4PKbap0YjI0Nj8WXpfPCe9+vKhmbijS1HApDCiMgQeez1HMrqnIwdHsmyq85qWr2IFVYvZEZaLGEhVsrqnJwsq2fs8KhBeR9/pJGRoTEpOZqJSdFQBW8eKeYLPSwkLNITXdorMgQ+Lqzh/7bnAvDjm6cRFjJ0p154qI3M9DgAdpysGLL39QcaGRk6i6Z6F+177UAxbrfH4Gok0CiMiAwyt9vbtOr2wI0zUrhyQuKFnzTA5o0ZBpgvjGhkZOhc49ssr6CqkfeOm3NdG+k/hRGRQfbX3afZdaoSR5iN/7nxEkNqmOsLIx+YLIy4NDIyZCI6XaL+7Pu5xhUiAUk9IyKDqLqhmYdfPUwqZXxrXgIp9UegvtMB/b10t48uHRWPzWqhoKqR05UNpA3JuxqvqbmVCKOLMKEtHxdzprqRlFj99KV3FEZEBtEjrx8mvL6Q7PBvE7HTCTu7Oai/l+72QaQ9hGkjY9mbX8XO3ArSBn5PPr+knpGhd0NSNRSfJDv7dW6/fHTHAyZY8Vf6T2FEZJDsO13Fnz7IY4qllgicPV++O0Qf0vPGDGNvfhU7TlZwq0nCiHpGhpBv1d/7q37J/XZgn+/WxgQr/kr/KYyIDIJWt4fvv3gAjweunTwcchnUy3d7Y27GMJ56+4S3b+TyOMPqGEoaGRlCvlV/m+vKWPr0TirrXTx4wySunjDcNCv+Sv8pjIgMhLOWdH9lbyHuwhPMDbdx95QYbxgx2GUZw7BY4ERpPZUNkUG/MFVzq5sWtwfMs8ab8eLSCY1L59J5kfw6+yhrD0dx9TWZRlclAUBhRORidbOk+83AzW37dPyLIekLuZBYRyiTkqI5XFTLwcJqrjS0msGnKRrj/MfcUax96xgfnKzgSFEtk4wuSPyewojIxWoo77Kk+8ObD/PO0TImJUXz6JdmYLVY/KZ5b+6YYb4wUmOCMOKdojHR6vd+Izk2nOunJPHqgSL++O+TPDzf6IrE32mdEZGBkjiRbXUjWZcTzWHLWO667VasI2d5+0T8IIhAx3oj+wuqDa5k8LWNjITZ9DFnhP+8cgwAmz4qoKqx2eBqxN/pLBUZIA3NrXxv034All6RwZTUGIMrOtfcDG8YyS2vv8CRga8tjNiHcOl96TBndDwz0mJxtbh59UCR0eWIn9NZKjJANrx7koKqRtKHRfBfn+phB16DjYgJZ0xiJB4TbB3S2DYyojBiCIvFwtd9oyP/3HfG4GrE3+ksFRkgbb/9/fILM4m0+287VtvoSLBr6xkJD7Fd4EgZLJ+ZnkJyTDhVDS6jSxE/168wsnbtWjIyMggPD2fevHns2LGjx2MPHjzIF77wBTIyMrBYLDzxxBP9rVXEL9W7Wtq/vnP+aOaPM/aqmQtp6xsJdm3TNKEaGTFMqM3Kkis6VmH1YIIhOemXPp+lGzduJCsri1WrVrF7925mzpzJokWLKCkp6fb4hoYGxo4dy8MPP0xycvJFFyzib556+wQAKbHhPPjpyQZXc2Gdw0hjEF/+2qieEb/wH3NHEe77f7DvdPA3Tkv/9Pksfeyxx1i2bBlLly5lypQprFu3DofDwYYN
G7o9/rLLLuORRx7hy1/+Mna7vdtjRAJCVT4U7ulye+utLeQe/giABz45AUeY/07PtEmLj2B4VBgAh4tqDa5m8KiB1T/EOcK47pIRALy8p9DgasRf9emT0+VysWvXLlauXNl+n9VqZeHChWzfvn3AinI6nTidzvbva2pqBuy1Rfqlm4XNAK4Frg2DZms40yeMNaa2PrJYLEwbGQu5cLCgmllGFzRI2i/tVc+I4W7OHAk5sCO3gpNl9YxJjDS6JPEzffqVoaysjNbWVpKSuu6ylZSURFHRwF26tXr1amJjY9tv6en+sUaDmFjnhc3u3obr62/xQOwT3Oj8GQ8mrsGyIrA2AJuaGgvA/sLgHTbvaGDVyIjR0uMiAPB4YP07JwyuRvyRX56lK1eupLq6uv2Wn59vdEkiXokT8aTM5KEdIbxcPIKCiIl8844vETJs9IWf60cy071hxFV0mNoTH54z/URV4J9zamD1T3/ddZqS2iajyxA/06dpmsTERGw2G8XFxV3uLy4uHtDmVLvdrv4S8VvPvX+KF3bmY7HA44szSYmNMLqkPktJSaMRO4+HrIX/W3vuAUGw3bsaWP3PJSkxHCx088d/5/LgDf7f7C1Dp09naVhYGLNnzyY7O7v9PrfbTXZ2NvPna/MBCX77Cqr50T8+BuDBGyZz7aQRBlfUT3Hp/HHWn7nR+TN+OfopuHtbx+3z671TUp12IQ5EbdM0dvWM+I0vzk4D4Lntp6hp0hLx0qHPvzJkZWWxfv16nnnmGQ4dOsS9995LfX09S5cuBWDJkiVdGlxdLhd79uxhz549uFwuCgoK2LNnD8eOHRu4P4XIEHn41cO0uD3cPDOVb1wdGA2rPZmXOYODnjE8eyqO5qQZ3j10UjMh0T9Xj+2rjgZW7ZTnL+ZmxDNhRBS1zhb+9H6e0eWIH+lzGFm8eDGPPvooDz30EJmZmezZs4fNmze3N7Xm5eVx5kzH0r+FhYXMmjWLWbNmcebMGR599FFmzZrFXXfdNXB/CpFBVlHvXUGyurGZGWmx/OILM7AE+HawmenxJESGUdvUws6TFUaXM+A6Lu3VyIi/sFos3HPNOAD+8O7J9v9HIv2aTF2xYgWnTp3C6XTywQcfMG/evPbHtm7dytNPP93+fUZGBh6P55zb1q1bL7Z2kSFR09TMQ38/CHgXNtvwtcuICAv8f+BsVguf9K3/8Mr+4Ns7ROuM+KebM1NJjQ2nrM7Jpt0FRpcjfkJnqUibbhY1c+bt5uE/vICtPAeAn3xuKolRwdNc/dmZqYB3X53mVrfB1QwsbZTnn0JtVu66yjvF+eTbx2l1a4l46ePVNCJBq4dFzezAzwHCwB0SQUpKmhHVDZr5YxNIiAyjvN7Fe8fLuWbicKNLGjBqYPVfX56bzq/fPMqp8gZePXCGm2akGl2SGEy/MojAOYuaNf7nm6xM/A03On/GF9yr2X/j37Gu2BnQl7p2J8Rm5TPTUwD4x97gWqpbDaz+yxEWwp3zMwD47VvH8Xg0OmJ2CiMinSVOpC5hGne+6uL/Oz2M3NDxrPzPLzP9smuCLoi0uTnT+1vpaweKgqqhUA2s/u1rV2TgCLPx8Zkasg91v9GqmIfCiEgn9a4W7tywgx0nK4i2h/DsXfOYkzHswk8MYLNHxZMSG06ts4VtOaVGlzNgOqZp9DHnj+Ijw1jiGx353+yjGh0xOZ2lIp384KWD7DpVSUx4CM/dNY9LR8UbXdKgs1ot3DQj+KZq1MDq/5ZdNYaIUBv7C6rZeiR4grD0nc5SEaC2qQWAI8W1xDlCeX7Z5cxMjzO2qCF0o6+B8M3DJThbguOqmvbl4EP1MeevEqLs3DHfu6/TExodMTWdpWJ6lfUuvv/SfgBiwkN4/q7LmTYy1uCqhtbMtFhGxkXQ4Gpl16ngWACt3ukNmI5Q9Yz4jbKccy6fv3diHZeGnqI0/1hQTRNK3+jSXjG18jonX/39B9hK68EOP//8dMakxhhd1pCzWCzc
OCOFp94+wdtHy7jC6IIuktvtocHlHRkJD9PHnOEcCd7NFzctO+eheGCTDRqsdh54PYZrJt4S8KsbS9/pLBVzqcpv3wCusrGZh17cj628gVkRJeCGMQmRBhdonM9M94aRnbkVAT9m2tDpqqAIjYwYLy7duwt0D5svVp8+QOy/llNYWMC7x8q4akLwrHcjvaMwIuZx1sJm8cBa8K5s5sb7m5sjwbj6DNY2VdNU7fb+TAJYg2+KxmoBu9YZ8Q9x6T1eHt95UvR/3zjKleMTNTpiMgojYh6+hc2qbljLg9ucnK5qJDEqjJ/fOp2RcRHeIBKka4n0hsVi4ZpJw9m7w+hKLl69b4omMiwEC/pHLVCE2qx8eKqSt4+WBdVqwHJhAT4YK9J3397axGuVyVTFTuEn99zOyCnzITXT1EGkzVXjE40uYUC0N6/aNUUTSG70rQb8yGuHcWvPGlNRGJHg0s1md2238lPeK2YKq5tIi4/ghbsvZ1SCw7ha/dAV4xKx+gYSSuucxhZzEdrCSKSaVwPKl+akERlm40BBDa8eKDK6HBlCOlMlePSw2V2bBKDBY8cRN4KnvjHfOzUjXcQ6Qhk/IhqqYP/paq6baHRF/dN2JU2kXR9xgSQuIpS7rhrL/2Yf5VdbjrBoahIhNv3ObAY6UyV4dN7sLrHjX9HCqka+9+J+SutcRMUnseYbnyU5NtzAQv3b1JQYqILDRbVcZ3Qx/VTXNk0TpmmaQHPXVWP4v+25nCit52+7T7P4slFGlyRDQJFTgk/iRG8PSGomJ0LH8/mXGthaO5KmxGmsuUdB5EImJkcBkFNca3Al/dfg8k3TaGQk4ESHh7L82vEAPPHG0aDavFF6pjAiQSunuJbFT71PUU0TE0ZE8cLd8xkRoyByIROTogE4WVYfsP8Q1Ds1TRPIbr98NCmx4ZypbuK5908ZXY4MAYURCUr7Tldx25PbKa11Mjk5mv/v7ssZHh3gi2cMkaQY78+pxe3hYGGNwdX0T0cDq6ZpAlF4qI0HPjkBgLVvHaO6sdngimSwKYxI0DlQWMN/rP+AqoZmZqbH8cLdl5MYpSDSW53X5ThcFKBhxNfA6tDVNAHri7PTGD8iisqGZn7z5lGjy5FBpjAiQeehlw9Q52zh8rHD+NNd84hzhBldUsDKKQrMvpHaJu9v0lHhCiOBKsRm5X9uvASAp9/L5WRZvcEVyWBSGJGgse2od8dPZ4ub6yaP4Omlc4lSz8BFORygYaRtWD82ItTgSqRPztrVd0F0IUsyqhjeWsrP/3XI6OpkEOmTWgKex+Phd9uO88prR7jGDldNSCTr9tmEhShrX6yc4lo8Hk/A7ROiMBJgzrOr74+B79rtLPz4Ed47lsEVQbJKsHSlMCKBp9POuy1uD7/bepzNB4sYbykA4DuLJmFVELloVgtUNjRTWusMuKuQahRGAsv5dvUty8GxaRnxllp+/M+PeeX/XYXNGljhWC5MYUQCy1mrrIYA9wP3t/WnhjqwRuo3p4GQFBPO/io4XlofcGFEIyMB6Dy7+gJE2W18UFTL8x+c4o75GUNXlwwJhREJLL5VVguu+zWr3mvmTHUT9hAr3140ifljE0y/8+5ASo2LgCrILa9n/rgEo8vpky5hxG1wMTIgbr88gw+2uvnla0dYNC2ZEdGBFZDl/DSWLQHp/i31vFGVQnXcFH5y71eZf+UntfPuAGvbuyc3wK5i8Hg81DR51xmJc2hkJFh8Znoy00fGUtvUwk//qWbWYKMwIgGj0dXK2q3HAO8VM1dPHM4/77+SqamxBlcWnFJ8y+YH2iWVdc4WWn3bz2uaJnjYLBZ+fut0rBb4+95C3vFdPSfBQWFEAsKBgmpuWvMO/9rv3Vb8tjlp/PFrl2kNkUGU6hsZCbQw0jZFExZiJTxUK7AGk+lpsSzx9Yv84KUDAbtdgZxLYUT8Wqvbw5PbjnPrb//N8dJ6hkV6w8ed8zPUUT/I2qZpTlU04PaNNASCqgY1rwaz
b10/kRHRdnLLG3h8S47R5cgAURgRv5VTXMsX173H6lcP09zq4fopSfzmP2YZXZZpDI+2E2qz4GpxU1jdaHQ5vdYWRuIURoJSdHgoP7t1OgBPvXOCD3MrDK5IBoLCiPgdZ0sr6//5Nt/+9f/hyv+Iy+x5rPukjScX2oitO2l0eaYRYrWQPswBQG5Zg8HV9EJVPhTuwXV6N1MtJ5kbnuddybNMvz0Hm09NSeKLs9PweOBbf9lLg6vF6JLkIunSXvErO05W8MTf3uT3tfexLNTZ8cC/fTfwrtToCKxLTQPVmIRITpTWc7K8nisn+PH6LZ3Wn7kOuM4OlABP+R7X35nAd1ao/OFlLZTlFJBTbufhVw/z489NM6gwGQgKI+IXCqsaWf3qYf6xt5CpliIcdicfzfklmZfO7bKLLKC1RIbQmMRIAE6W+nkTq2/9GT6/nt8fCuXFPQV8ftZIvn7lGO/j+jsTuHpYKj4KeBposNtZuP0RFkwaznWTk4yoUAaAwogYqqm5ld+/c4K1bx2nsbkViwVumJoMx2DWpfO8a4eIYTJ8YSS33M/DSJvEiez3eDjoCeOW5EsgdazRFcnF6uVS8Vl/3ssr/++q9sZrCSwKI2IIT1Ue2w8c5Q/vnqCo2slYYOrIGO6+eizjLfVwzOgKBWBs28hIAF3eW1rrnd4bHm2/wJESMC6wVPzEpCgOFjWz4vndbLx7vjbJDEAKIzLkPj50kLF/vpYrPE6uAGj7N6MceNH3teb4/cKY4d4wklfRQHOrm1Cb/3/IK4yYz4M3TCb7hWo+yqti9auHWPXZqUaXJH2kMCJD5lR5Pb987Qi5+9/jFbuTb7WuYGbmXL50WRoRIWctTqU5fr+QFB1ORKiNxuZWTlc2tveQ+LOimiYARiiMmEZyTDiPfimDu5/dxR//ncvk5GgWXzbK6LKkDxRGZNCV1TlZk32UP32QR4vbwzTfL9cP3nEzIybOM7Y4OS+r1UJGYiSHztRwsqzO78NIbVMLtb59adLiHQZXI0Pp+qnJ/L9PTuDX2Uf5/osHGDUsMuA2eDQz/x9zlYDV4GphTfZRFjyylWe2n6LF7eGaicNZ82XvwmUjovSbayBo6xs54e9X1ADFtd5RkcQoOxFhWgrebP5r4QQ+OzOVFreHe57bxbGSOqNLkl7SyIgMOGdLKy/syOc3bx1rn7+fPjKWlZ+ezBXjE70LUUnAyEj0jjAEQhNrcbU3jKQP0xUVZmSxWHjkizPIr2hgT34Vt//+A/5yz/z2xfvEfymMyIBpbnXz112nWZN9lMLqJlIp45NxLu64PIOrJ0RitZyGwtNaETPAjEmMAgLj8t6i2ibAQrqmaMyl02dKOPD0DWF8d1Mx+RUN/NdT5ay972aSYsKNq08uSGFELlpLq5uX9hTy6+yj5FV4lw2fEV3DJvd3CGlqgq14b53pahn/5/uAn0ItUy0nsRUVQKHvA91PG4yLqpqACI2MmEUPC6LFAesA7NDQaOc/11n41d03aQ0SP6YwIv3W6vbwyv4zPPFGTns/QWJUGPcuGM/toyoJ2dAEn18PiRPPfbKf/mMmnPMBPwV4xQ600HV59eU7/O7/YX5lAxDBuOFRRpciQ+F8C6IB5bn7SXh9BbWVxXzpd+/x7F3z9HfDTymMSJ81t7p56aMCfrfteHsIiXeE8o1rxrFk/mgcYSFQWO09OHGiVlENNN18wN/97C4Kqhr5yS1TuTSi1BtUGsr9MIx4dxceP0L/4JjGeRZEaxt7TYuP4GBFE7et2876O+dw6aj4oatPekVhRHqtqbmVjTvzeertExRUNZJKGfPCG7ll1kg+l5mAI7QWyg54D1ZfSGA76wM+NM3Nwcoz7GgaxaXp/vtBXtXQDKDffqWLX3xhBgX/auJAQQ1ffvJ9fnrrNG6b419B2uwURuSCqhpc/OmDPP7475OU1bkAmBZVw4vu7xDqboKP8N7Opr6QoDElNYZX9p/h48IamGC58BMMNDIugki7
PtqkQ1xEKBvvnkXWn/fw2sFivvPXfew/Xc33b7yE8FBdAu4PdMZKj46V1PLHf+fyt92naWp2A97hzm9cM47bUssJPV9PCKgvJIhMHxkLwK5TlXiIP3sfZb+iKRrpTqQ9hN99dTZr3jzG42/k8Oz7p9h+opz//XImU1NjjS7P9BRGpAu328O2o6X88d+5vJ1T2n7/5ORoll01lpszU737kxRWeh9QT4gpzB4dT6jNQkFVI0XVEaQYXdB5TExSGJHuWa0WHlg4gZnpsXz7r/s4VlLHLWv/zd1Xj2XFtRO0UJ6BFEYEgMKqRv7y4Wn+siuf074mQIsFPnVJEks/MYbLxw7DYvHn34dlMEXaQ5iVHs+O3Ar25Ff7dRiZNlK/5cpZzuphWxANW74cy8PbSnghx8Pat47z0keF/OCmKSyamqTPOgMojJhYU3Mr2YdK+POH+bx9tBSPx3t/dHgIt81J5875GYxK0OJR4nXF+AR25FawK6+CTxtdzFla3J72DzOFEWnXwzok4F2LZHWog09/4RW+l11FQVUj9zy3i5lpsWRdP4mrJyQqlAwhhRGTcba08nZOGf/cV8gbHxdT72ptf2zemGF8eW46N0xN6RiurMrv/hp+XS1jOtdPSeaJN47y4ckqCDW6mq5OVzaSATjCbIxJ8O/N/GQInW8dkrIcLJuWcc1IK1uyrua3bx3nD++eZO/pau7csIMZabF87YoMbpyRgv3sXcVlwCmMmEBlvYttOaW8daSENw+XtO9qCt4rD27OTOW2Oenn7shalQ9r50JzQ/cvrKtlTOWSlGgmJUXTXOI2upRz5JTUkoF3Uz+rVb/NSifnWYekjSMshP9eNImvfSKDdVuP8+z7p9h3upqsP+/l5/86xBcuTeOWWSO5JCVmiIo2H4WRINTU3MpHeVXsOFnBtpwSPsqvap+CAUiKsfOZ6SncNCOVWelxPX94N5R7g4hWURW8m5B9cXYaL726DwC3x+M3237vza/ietBVEdJ3nUZ5E4H/uRTun5LI5gNFPL+/nr210Tz59gmefPsEk5Ki+fT0ZBZMGsGMkbEKvgNIYSTAeTweCqoaOVhYw+68SnaerGB/QTXNrZ4ux12SEsO1k4Zz7eQRzB4V37eTSFfMiM/iuelsedM7ZP3+iQquGGlwQXivANuTXwVA5qg4Q2uRAHKefpJYYDFwW6iDbZ//Jy8cgTcPl3CkuJYjxbU88cZREiLDuGpCInPHJDB7dDwTRkQpnFyEfoWRtWvX8sgjj1BUVMTMmTNZs2YNc+fO7fH4v/zlL/zgBz8gNzeXCRMm8Itf/ILPfOYz/S7ajDweD6W1TnLLG8gtqyenuJaDhTV8fKaG6sbmc45PirFzWcYwPjE+kQWThpMS28MGUT31hID6QuQcMeGh3DQ9FfbDM9tzufTyVsMXjXr7aKl35VU7XJKsYXTppQvsa9PWU7IgzcaCuZlUNzbz2sEi3jpcwjtHyyivd/HSnkJe2lMIeBv/M9PjmJISw8SkaCYlRzN+RJTh50eg6HMY2bhxI1lZWaxbt4558+bxxBNPsGjRIo4cOcKIESPOOf69997jK1/5CqtXr+amm27i+eef55ZbbmH37t1MmzZtQP4Qga7V7aGmsZnSOifFNU0U13j/W1rrpKi6iVMVDZwqr6ehU7NpZyFWCxOSopk+MobLMoYxb0wC6cMiOjrBq/KhsJsTrqEMNt7Rc08IqC9EznHrrJGw39s0+qN/fMzPb51m2FUHHo+HdduOt38fatNvptIHvegnafulLBa4LRVuS7XRvGAE+ytsvFVkZ9epSvbkV1Hb1MI7R8t452hZ+1OtFhgZH0F6vMN7GxZB+jAHI+MiSIyykxhtJzLMpqt2AIvH4/Fc+LAO8+bN47LLLuM3v/kNAG63m/T0dO6//36++93vnnP84sWLqa+v55///Gf7fZdffjmZmZmsW7euV+9ZU1NDbGws1dXVxMQM7m8+Ho+HFreHVrfvv60emt3u
Lt+3+L5vbvUe52ptpdHlpqm5lcbmVprab53vc1PvbKG6sZnqxmZqmprbv65zttCb/wttf7EzEiIZNzyKKakxTEmJYUJSVM/d3r1pQl38LDgSu39cfSFytsI98NQ13OT6GQfcY7h20nCyPjWJaSNjhvRD1ePxsObNYzy2JYdZIad4MWQl3L1NU4oyMPrw2dni9pBb7h2xPlXu/eVxT7mNw41xF3yb8FCrN5hE2UmMCiM6PJTo8BDfLbTLf8NDbISHWrG3/TfURnhIx39DbP7SxdWht/9+92lkxOVysWvXLlauXNl+n9VqZeHChWzfvr3b52zfvp2srKwu9y1atIiXXnqpx/dxOp04nc7276urvTvA1tTU9KXcC1r2fx+yv6DKFzSgtdWNu0/RbGDFRoQwPNrO8OhwhkfZGRFtZ3h0GGnDHGTY60i11fp+82sGKr23OnDWgbOnFy0/BnX18NlfQ8L4cx+PGAZxaecvbIB/7hLgauvA6SFrRjO/2nWYgv0f863924iLCCPeEUpYiNW7Yl5/eTy4PeABPHi/8Pjuw+Px3Q/1zhbK611MBO6eaqHmY4+3Nv19lYFgjYU7sqGx4tzHGsph093wh8+33zXCd7vS970nJILqm39LoctBSU0TxXVOSn0j3kcbHBxtcNDoctPghLy6OvIGomQL2GxWbBawWS3YLBbvf60WrJ2+DrFasPoet1i8zekWC/zslmlMGuCpzrZ/ty807tGnMFJWVkZraytJSUld7k9KSuLw4cPdPqeoqKjb44uKinp8n9WrV/OjH/3onPvT04P7N/T8wXzxh5cO5quLKWVd+JAhsqXti4evMrIMkU5q4adfNLqIPtny4OC9dm1tLbGxPV/t5pdX06xcubLLaIrb7aaiooKEhATTz63V1NSQnp5Ofn7+oE9ZmZ1+1kNDP+ehoZ/z0NDPuSuPx0NtbS2pqannPa5PYSQxMRGbzUZxcXGX+4uLi0lOTu72OcnJyX06HsBut2O327vcFxcX15dSg15MTIz+og8R/ayHhn7OQ0M/56Ghn3OH842ItOlTt0tYWBizZ88mOzu7/T632012djbz58/v9jnz58/vcjzAli1bejxeREREzKXP0zRZWVnceeedzJkzh7lz5/LEE09QX1/P0qXenoQlS5YwcuRIVq9eDcADDzzANddcw69+9StuvPFGXnjhBT788EOeeuqpgf2TiIiISEDqcxhZvHgxpaWlPPTQQxQVFZGZmcnmzZvbm1Tz8vKwWjsGXK644gqef/55/ud//ofvfe97TJgwgZdeeklrjPST3W5n1apV50xjycDTz3po6Oc8NPRzHhr6OfdPn9cZERERERlI/rdCioiIiJiKwoiIiIgYSmFEREREDKUwIiIiIoZSGAkSTqeTzMxMLBYLe/bsMbqcoJKbm8vXv/51xowZQ0REBOPGjWPVqlW4XC6jSwt4a9euJSMjg/DwcObNm8eOHTuMLinorF69mssuu4zo6GhGjBjBLbfcwpEjR4wuK6g9/PDDWCwWvvnNbxpdSsBQGAkS3/nOdy643K70z+HDh3G73Tz55JMcPHiQxx9/nHXr1vG9733P6NIC2saNG8nKymLVqlXs3r2bmTNnsmjRIkpKSowuLahs27aN5cuX8/7777Nlyxaam5u5/vrrqa+vN7q0oLRz506efPJJZsyYYXQpgcUjAe9f//qXZ/LkyZ6DBw96AM9HH31kdElB75e//KVnzJgxRpcR0ObOnetZvnx5+/etra2e1NRUz+rVqw2sKviVlJR4AM+2bduMLiXo1NbWeiZMmODZsmWL55prrvE88MADRpcUMDQyEuCKi4tZtmwZzz77LA6Hw+hyTKO6upphw4YZXUbAcrlc7Nq1i4ULF7bfZ7VaWbhwIdu3bzewsuBXXV0NoL+/g2D58uXceOONXf5eS+/45a690jsej4evfe1r3HPPPcyZM4fc3FyjSzKFY8eOsWbNGh599FGjSwlYZWVltLa2
tq/c3CYpKYnDhw8bVFXwc7vdfPOb3+QTn/iEVsEeYC+88AK7d+9m586dRpcSkDQy4oe++93vYrFYzns7fPgwa9asoba2lpUrVxpdckDq7c+5s4KCAm644Qa+9KUvsWzZMoMqF+mf5cuXc+DAAV544QWjSwkq+fn5PPDAA/zpT38iPDzc6HICkpaD90OlpaWUl5ef95ixY8dy22238Y9//AOLxdJ+f2trKzabja9+9as888wzg11qQOvtzzksLAyAwsJCFixYwOWXX87TTz/dZQ8m6RuXy4XD4eCvf/0rt9xyS/v9d955J1VVVbz88svGFRekVqxYwcsvv8zbb7/NmDFjjC4nqLz00kvceuut2Gy29vtaW1uxWCxYrVacTmeXx+RcCiMBLC8vj5qamvbvCwsLWbRoEX/961+ZN28eaWlpBlYXXAoKCrj22muZPXs2zz33nD5YBsC8efOYO3cua9asAbxTCKNGjWLFihV897vfNbi64OHxeLj//vt58cUX2bp1KxMmTDC6pKBTW1vLqVOnuty3dOlSJk+ezIMPPqgpsV5Qz0gAGzVqVJfvo6KiABg3bpyCyAAqKChgwYIFjB49mkcffZTS0tL2x5KTkw2sLLBlZWVx5513MmfOHObOncsTTzxBfX09S5cuNbq0oLJ8+XKef/55Xn75ZaKjoykqKgIgNjaWiIgIg6sLDtHR0ecEjsjISBISEhREeklhROQCtmzZwrFjxzh27Ng5IU8Di/23ePFiSktLeeihhygqKiIzM5PNmzef09QqF+d3v/sdAAsWLOhy/x//+Ee+9rWvDX1BIt3QNI2IiIgYSh14IiIiYiiFERERETGUwoiIiIgYSmFEREREDKUwIiIiIoZSGBERERFDKYyIiIiIoRRGRERExFAKIyIiImIohRERERExlMKIiIiIGEphRERERAz1/wMI9Jdvlj3zsAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "plt.plot(np.linspace(-5, 5, 2000).astype(np.float32), np.exp(net_out.cpu()))\n",
    "\n",
    "plt.hist(\n",
    "    sim_out[\"rts\"] * sim_out[\"choices\"],\n",
    "    bins=50,\n",
    "    histtype=\"step\",\n",
    "    fill=None,\n",
    "    density=True,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lan_pipe",
   "language": "python",
   "name": "lan_pipe"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
